// Unit tests live in a separate `tests` module file. They are excluded under Miri
// (NOTE(review): presumably because they need a real wgpu adapter - confirm).
#[cfg(all(test, not(miri)))]
mod tests;
3
4use crate::shader::constants::{
5 MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6 REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::num::NonZeroU8;
27use std::simd::Simd;
28use std::sync::Arc as StdArc;
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::{fmt, iter};
31use tracing::{debug, warn};
32use wgpu::{
33 AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34 BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35 BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36 ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37 DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38 PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39 RequestDeviceError, ShaderModule, ShaderRuntimeChecks, ShaderStages,
40};
41
/// Errors produced while encoding records on a GPU device.
#[derive(Debug, thiserror::Error)]
enum RecordEncodingError {
    /// The number of records passed in does not fit into `u16`
    #[error("Too many records: {0}")]
    TooManyRecords(usize),
    /// A previous `create_proofs` call did not complete successfully, so the instance refuses
    /// to be used again
    #[error("Proof creation failed previously and the device is now considered broken")]
    DeviceBroken,
    /// A buffer-mapping callback reported a failure
    #[error("Failed to map buffer: {0}")]
    BufferMapping(#[from] BufferAsyncError),
    /// Waiting for the submitted GPU work failed
    #[error("Poll error: {0}")]
    DevicePoll(#[from] PollError),
}
58
/// Proofs read back from the GPU, viewed directly inside the mapped `proofs_host` buffer.
///
/// `proofs_host` is unmapped when this wrapper is dropped, so `proofs` must not be used after
/// the wrapper goes out of scope.
struct ProofsHostWrapper<'a> {
    /// Reference into the mapped memory of `proofs_host`
    proofs: &'a ProofsHost,
    /// Host-side buffer that backs `proofs`
    proofs_host: &'a Buffer,
}
63
64impl Drop for ProofsHostWrapper<'_> {
65 fn drop(&mut self) {
66 self.proofs_host.unmap();
67 }
68}
69
/// A compatible GPU adapter together with one or more opened logical devices for it.
#[derive(Clone)]
pub struct Device {
    /// Index of the adapter in the order it was enumerated
    id: u32,
    /// One `(device, queue, compiled shader module)` triple per requested queue
    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
    /// Adapter description reported by wgpu
    adapter_info: AdapterInfo,
}
77
78impl fmt::Debug for Device {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct("Device")
81 .field("id", &self.id)
82 .field("name", &self.adapter_info.name)
83 .field("device_type", &self.adapter_info.device_type)
84 .field("driver", &self.adapter_info.driver)
85 .field("driver_info", &self.adapter_info.driver_info)
86 .field("backend", &self.adapter_info.backend)
87 .finish_non_exhaustive()
88 }
89}
90
91impl Device {
92 pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
94 where
95 NOQ: Fn(DeviceType) -> NonZeroU8,
96 {
97 let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
98 let instance = Instance::new(InstanceDescriptor {
99 backends,
100 flags: if cfg!(debug_assertions) {
101 InstanceFlags::debugging().with_env()
102 } else {
103 InstanceFlags::from_env_or_default()
104 },
105 memory_budget_thresholds: MemoryBudgetThresholds::default(),
106 backend_options: BackendOptions::from_env_or_default(),
107 display: None,
108 });
109
110 let adapters = instance.enumerate_adapters(backends).await;
111 let number_of_queues = &number_of_queues;
112
113 adapters
114 .into_iter()
115 .zip(0..)
116 .map(|(adapter, id)| async move {
117 let adapter_info = adapter.get_info();
118
119 let (shader, required_features, required_limits) =
120 if let Some((shader, required_features, required_limits)) =
121 select_shader_features_limits(&adapter)
122 {
123 debug!(
124 %id,
125 adapter_info = ?adapter_info,
126 "Compatible adapter found"
127 );
128
129 (shader, required_features, required_limits)
130 } else {
131 debug!(
132 %id,
133 adapter_info = ?adapter_info,
134 "Incompatible adapter found"
135 );
136
137 return None;
138 };
139
140 let devices = iter::repeat_with(|| async {
143 let (device, queue) = adapter
144 .request_device(&DeviceDescriptor {
145 label: None,
146 required_features,
147 required_limits: required_limits.clone(),
148 ..DeviceDescriptor::default()
149 })
150 .await
151 .inspect_err(|error| {
152 warn!(%id, ?adapter_info, %error, "Failed to request the device");
153 })?;
154 let module = if cfg!(debug_assertions) {
155 device.create_shader_module(shader.clone())
156 } else {
157 unsafe {
161 device.create_shader_module_trusted(
162 shader.clone(),
163 ShaderRuntimeChecks::unchecked(),
164 )
165 }
166 };
167
168 Ok::<_, RequestDeviceError>((device, queue, module))
169 })
170 .take(usize::from(
171 number_of_queues(adapter_info.device_type).get(),
172 ))
173 .collect::<FuturesOrdered<_>>()
174 .try_collect::<Vec<_>>()
175 .await
176 .ok()?;
177
178 Some(Self {
179 id,
180 devices,
181 adapter_info,
182 })
183 })
184 .collect::<FuturesOrdered<_>>()
185 .filter_map(|device| async move { device })
186 .collect()
187 .await
188 }
189
190 pub fn id(&self) -> u32 {
192 self.id
193 }
194
195 pub fn name(&self) -> &str {
197 &self.adapter_info.name
198 }
199
200 pub fn device_type(&self) -> DeviceType {
202 self.adapter_info.device_type
203 }
204
205 pub fn driver(&self) -> &str {
207 &self.adapter_info.driver
208 }
209
210 pub fn driver_info(&self) -> &str {
212 &self.adapter_info.driver_info
213 }
214
215 pub fn backend(&self) -> Backend {
217 self.adapter_info.backend
218 }
219
220 pub fn instantiate(
221 &self,
222 erasure_coding: ErasureCoding,
223 global_mutex: StdArc<AsyncMutex<()>>,
224 ) -> anyhow::Result<GpuRecordsEncoder> {
225 GpuRecordsEncoder::new(
226 self.id,
227 self.devices.clone(),
228 self.adapter_info.clone(),
229 erasure_coding,
230 global_mutex,
231 )
232 }
233}
234
/// Records encoder backed by the logical devices of a single GPU adapter.
pub struct GpuRecordsEncoder {
    /// Adapter id, also used in worker thread names
    id: u32,
    /// One encoder instance per device queue; indexed by rayon worker thread index
    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
    /// Thread pool with one worker per entry in `instances`
    thread_pool: ThreadPool,
    /// Adapter description reported by wgpu
    adapter_info: AdapterInfo,
    /// Used to compute parity chunks of each record before encoding
    erasure_coding: ErasureCoding,
    /// Mutex that every record waits on before encoding starts (see `encode_records`)
    global_mutex: StdArc<AsyncMutex<()>>,
}
243
244impl fmt::Debug for GpuRecordsEncoder {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 f.debug_struct("GpuRecordsEncoder")
247 .field("id", &self.id)
248 .field("name", &self.adapter_info.name)
249 .field("device_type", &self.adapter_info.device_type)
250 .field("driver", &self.adapter_info.driver)
251 .field("driver_info", &self.adapter_info.driver_info)
252 .field("backend", &self.adapter_info.backend)
253 .finish_non_exhaustive()
254 }
255}
256
impl RecordsEncoder for GpuRecordsEncoder {
    /// Encodes `records` of sector `sector_id` in place and returns the resulting contents map.
    ///
    /// Records are processed in parallel on the encoder's thread pool. Each worker thread uses
    /// the GPU instance with the same index. For every record, proofs are created on the GPU.
    /// Then the chunks (source or parity) whose s-buckets have proofs are XOR-ed with the proof
    /// hashes and packed into the front of the record.
    ///
    /// If `abort_early` is observed as set, remaining records are skipped and `Ok` is still
    /// returned (NOTE(review): callers presumably check `abort_early` themselves - confirm).
    ///
    /// # Errors
    ///
    /// - `TooManyRecords` if `records.len()` does not fit into `u16`
    /// - any error returned by proof creation on a GPU instance (reported for one record; the
    ///   other records may or may not have been processed)
    fn encode_records(
        &mut self,
        sector_id: &SectorId,
        records: &mut [Record],
        abort_early: &AtomicBool,
    ) -> anyhow::Result<SectorContentsMap> {
        let mut sector_contents_map = SectorContentsMap::new(
            u16::try_from(records.len())
                .map_err(|_error| RecordEncodingError::TooManyRecords(records.len()))?,
        );

        let maybe_error = self.thread_pool.install(|| {
            records
                .par_iter_mut()
                .zip(sector_contents_map.iter_record_chunks_used_mut())
                .enumerate()
                .find_map_any(|(piece_offset, (record, record_chunks_used))| {
                    // The returned guard is discarded immediately: this only waits until the
                    // global mutex is free and does not hold it while encoding
                    self.global_mutex.lock_blocking();

                    let mut parity_record_chunks = Record::new_boxed();

                    // Compute parity chunks of the source record
                    self.erasure_coding
                        .extend(record.iter(), parity_record_chunks.iter_mut())
                        .expect("Statically guaranteed valid inputs; qed");

                    if abort_early.load(Ordering::Relaxed) {
                        return None;
                    }
                    // Cast can't truncate: `records.len()` was checked to fit into `u16` above
                    let seed =
                        sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
                    let thread_index = rayon::current_thread_index().unwrap_or_default();
                    // Each pool thread has its own instance, so the lock is never contended
                    let mut encoder_instance = self.instances[thread_index]
                        .try_lock()
                        .expect("1:1 mapping between threads and devices; qed");
                    let proofs = match encoder_instance.create_proofs(&seed) {
                        Ok(proofs) => proofs,
                        Err(error) => {
                            return Some(error);
                        }
                    };
                    // Shadowing keeps the wrapper alive (and the buffer mapped) until the end of
                    // this closure, while the borrowed proofs are being read
                    let proofs = proofs.proofs;

                    *record_chunks_used = proofs.found_proofs;

                    // Pack chunks that have proofs into the front of the record. The write index
                    // never exceeds the s-bucket index being read, so source chunks are read
                    // before they can be overwritten
                    let mut num_found_proofs = 0_usize;
                    for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
                        .array_chunks::<{ u8::BITS as usize }>()
                        .zip(record_chunks_used)
                    {
                        for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
                            if num_found_proofs == Record::NUM_CHUNKS {
                                // Enough chunks collected: keep only bits below `proof_offset`
                                // (a shift by 8 yields 0 and clears the whole byte)
                                *found_proofs &=
                                    u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
                                break;
                            }
                            if (*found_proofs & (1 << proof_offset)) != 0 {
                                let record_chunk = if s_bucket < Record::NUM_CHUNKS {
                                    record[s_bucket]
                                } else {
                                    parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
                                };

                                record[num_found_proofs] = (Simd::from(record_chunk)
                                    ^ Simd::from(*proofs.proofs[s_bucket].hash()))
                                .to_array();
                                num_found_proofs += 1;
                            }
                        }
                    }

                    None
                })
        });

        if let Some(error) = maybe_error {
            return Err(error.into());
        }

        Ok(sector_contents_map)
    }
}
345
346impl GpuRecordsEncoder {
347 fn new(
348 id: u32,
349 devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
350 adapter_info: AdapterInfo,
351 erasure_coding: ErasureCoding,
352 global_mutex: StdArc<AsyncMutex<()>>,
353 ) -> anyhow::Result<Self> {
354 let thread_pool = ThreadPoolBuilder::new()
355 .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
356 .num_threads(devices.len())
357 .build()?;
358
359 Ok(Self {
360 id,
361 instances: devices
362 .into_iter()
363 .map(|(device, queue, module)| {
364 Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
365 })
366 .collect(),
367 thread_pool,
368 adapter_info,
369 erasure_coding,
370 global_mutex,
371 })
372 }
373}
374
/// All GPU state needed to create proofs on one logical device/queue.
struct GpuRecordsEncoderInstance {
    device: wgpu::Device,
    queue: Queue,
    /// Filled by buffer-mapping callbacks when a map operation fails
    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
    /// Set while proof creation is in progress. It stays set if creation fails, after which the
    /// instance refuses further work
    tainted: bool,
    /// Host-writable staging buffer for the initial ChaCha8 state
    initial_state_host: Buffer,
    /// GPU uniform copy of the initial ChaCha8 state
    initial_state_gpu: Buffer,
    /// Host-readable buffer the proofs are copied into
    proofs_host: Buffer,
    /// GPU storage buffer the proofs are written to by the shaders
    proofs_gpu: Buffer,
    // Bind group + compute pipeline pairs, one per shader stage
    bind_group_compute_f1: BindGroup,
    compute_pipeline_compute_f1: ComputePipeline,
    bind_group_sort_buckets_a: BindGroup,
    compute_pipeline_sort_buckets_a: ComputePipeline,
    bind_group_sort_buckets_b: BindGroup,
    compute_pipeline_sort_buckets_b: ComputePipeline,
    bind_group_find_matches_and_compute_f2: BindGroup,
    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
    bind_group_find_matches_and_compute_f3: BindGroup,
    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
    bind_group_find_matches_and_compute_f4: BindGroup,
    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
    bind_group_find_matches_and_compute_f5: BindGroup,
    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
    bind_group_find_matches_and_compute_f6: BindGroup,
    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
    bind_group_find_matches_and_compute_f7: BindGroup,
    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
    bind_group_find_proofs: BindGroup,
    compute_pipeline_find_proofs: ComputePipeline,
}
405
impl GpuRecordsEncoderInstance {
    /// Allocates all host/GPU buffers for one device queue and builds the bind group and compute
    /// pipeline of every proof-creation stage.
    ///
    /// Tables ping-pong between the `a` and `b` bucket/metadata buffers. Stage `fN` reads the
    /// buffers written by stage `fN-1` and writes the other set.
    fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
        // Created mapped so that the first `create_proofs` call can write into it directly
        let initial_state_host = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_host"),
            size: size_of::<ChaCha8Block>() as BufferAddress,
            usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
            mapped_at_creation: true,
        });

        let initial_state_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_gpu"),
            size: initial_state_host.size(),
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_sizes_gpu_buffer_size =
            size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
        // One allocation serves two purposes (bucket sizes and table 6 proof target sizes), so
        // it is sized for the larger of the two
        let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("bucket_sizes_gpu"),
            size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        // Second handle to the same allocation as `bucket_sizes_gpu`
        let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();

        let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_a_gpu"),
            size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_b_gpu"),
            size: buckets_a_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f2_gpu"),
            size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
                as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f3_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f4_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f5_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f6_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let metadatas_gpu_buffer_size =
            size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_gpu_buffer_size = size_of::<
            [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
        >() as BufferAddress;
        // Shared between metadatas `a` and table 6 proof targets, hence sized for the larger use
        let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_a_gpu"),
            size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        // Second handle to the same allocation as `metadatas_a_gpu`; it is written by f7 after
        // the last reader of metadatas `a` (f6) has run
        let table_6_proof_targets_gpu = metadatas_a_gpu.clone();

        let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_b_gpu"),
            size: metadatas_gpu_buffer_size,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let proofs_host = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_host"),
            size: size_of::<ProofsHost>() as BufferAddress,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let proofs_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_gpu"),
            size: proofs_host.size(),
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // f1: initial state -> buckets `a`
        let (bind_group_compute_f1, compute_pipeline_compute_f1) =
            bind_group_and_pipeline_compute_f1(
                &device,
                &module,
                &initial_state_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );
        let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );

        let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
            );

        // f2: buckets `a` -> buckets `b`, metadatas `b`
        let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
            bind_group_and_pipeline_find_matches_and_compute_f2(
                &device,
                &module,
                &buckets_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f2_gpu,
                &metadatas_b_gpu,
            );

        // f3: `b` -> `a`
        let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f3_gpu,
                &metadatas_a_gpu,
            );

        // f4: `a` -> `b`
        let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f4_gpu,
                &metadatas_b_gpu,
            );

        // f5: `b` -> `a`
        let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f5_gpu,
                &metadatas_a_gpu,
            );

        // f6: `a` -> `b`
        let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f6_gpu,
                &metadatas_b_gpu,
            );

        // f7: `b` -> table 6 proof targets (aliases of bucket sizes and metadatas `a`)
        let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
            bind_group_and_pipeline_find_matches_and_compute_f7(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
            );

        let (bind_group_find_proofs, compute_pipeline_find_proofs) =
            bind_group_and_pipeline_find_proofs(
                &device,
                &module,
                &positions_f2_gpu,
                &positions_f3_gpu,
                &positions_f4_gpu,
                &positions_f5_gpu,
                &positions_f6_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
                &proofs_gpu,
            );

        Self {
            device,
            queue,
            mapping_error: Arc::new(Mutex::new(None)),
            tainted: false,
            initial_state_host,
            initial_state_gpu,
            proofs_host,
            proofs_gpu,
            bind_group_compute_f1,
            compute_pipeline_compute_f1,
            bind_group_sort_buckets_a,
            compute_pipeline_sort_buckets_a,
            bind_group_sort_buckets_b,
            compute_pipeline_sort_buckets_b,
            bind_group_find_matches_and_compute_f2,
            compute_pipeline_find_matches_and_compute_f2,
            bind_group_find_matches_and_compute_f3,
            compute_pipeline_find_matches_and_compute_f3,
            bind_group_find_matches_and_compute_f4,
            compute_pipeline_find_matches_and_compute_f4,
            bind_group_find_matches_and_compute_f5,
            compute_pipeline_find_matches_and_compute_f5,
            bind_group_find_matches_and_compute_f6,
            compute_pipeline_find_matches_and_compute_f6,
            bind_group_find_matches_and_compute_f7,
            compute_pipeline_find_matches_and_compute_f7,
            bind_group_find_proofs,
            compute_pipeline_find_proofs,
        }
    }

    /// Runs the full proof-creation pipeline for `seed` and blocks until results are available.
    ///
    /// The returned wrapper borrows the mapped `proofs_host` buffer and unmaps it on drop.
    ///
    /// # Errors
    ///
    /// - `DeviceBroken` if a previous call did not complete successfully
    /// - `DevicePoll` if waiting for the submission fails
    /// - `BufferMapping` if remapping either host buffer failed
    ///
    /// After any error, the instance stays tainted and all later calls return `DeviceBroken`.
    fn create_proofs(
        &mut self,
        seed: &PosSeed,
    ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
        if self.tainted {
            return Err(RecordEncodingError::DeviceBroken);
        }
        // Cleared only on the success path below
        self.tainted = true;

        let mut encoder = self
            .device
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("create_proofs"),
            });

        // `initial_state_host` is mapped at this point: either from creation or by the
        // `map_buffer_on_submit` request of the previous successful call
        self.initial_state_host
            .get_mapped_range_mut(..)
            .copy_from_slice(&block_to_bytes(
                &ChaCha8State::init(seed, &[0; _]).to_repr(),
            ));
        self.initial_state_host.unmap();

        encoder.copy_buffer_to_buffer(
            &self.initial_state_host,
            0,
            &self.initial_state_gpu,
            0,
            self.initial_state_host.size(),
        );

        {
            let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("create_proofs"),
                timestamp_writes: None,
            });

            // Dispatch order matters: every stage consumes what the previous one wrote
            cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
            cpass.set_pipeline(&self.compute_pipeline_compute_f1);
            cpass.dispatch_workgroups(
                MAX_TABLE_SIZE
                    .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
                1,
                1,
            );

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_proofs);
            // NOTE(review): truncating division; assumes `NUM_S_BUCKETS` is a multiple of
            // `find_proofs::WORKGROUP_SIZE` - confirm
            cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
        }

        encoder.copy_buffer_to_buffer(
            &self.proofs_gpu,
            0,
            &self.proofs_host,
            0,
            self.proofs_host.size(),
        );

        // Re-map the input staging buffer for the next call right away
        encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });
        // Map the proofs buffer for reading once the copy above has completed
        encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });

        let submission_index = self.queue.submit([encoder.finish()]);

        // Block until this submission (and its map callbacks) has completed
        self.device.poll(PollType::Wait {
            submission_index: Some(submission_index),
            timeout: None,
        })?;

        if let Some(error) = self.mapping_error.lock().take() {
            return Err(RecordEncodingError::BufferMapping(error));
        }

        let proofs = {
            // The temporary `BufferView` is dropped at the end of this statement so that
            // `unmap()` in `ProofsHostWrapper::drop` does not conflict with an outstanding view;
            // the pointer is still used afterwards, while the buffer remains mapped
            let proofs_host_ptr = self
                .proofs_host
                .get_mapped_range(..)
                .as_ptr()
                .cast::<ProofsHost>();
            // SAFETY: `proofs_host` was created with exactly `size_of::<ProofsHost>()` bytes and
            // is mapped for reading at this point. It stays mapped until the returned wrapper
            // (which borrows `self.proofs_host`) is dropped and unmaps it.
            // NOTE(review): also relies on the mapped range being sufficiently aligned for
            // `ProofsHost` and on the GPU-written bytes being a valid `ProofsHost` - confirm.
            unsafe { &*proofs_host_ptr }
        };

        self.tainted = false;

        Ok(ProofsHostWrapper {
            proofs,
            proofs_host: &self.proofs_host,
        })
    }
}
817
818fn bind_group_and_pipeline_compute_f1(
819 device: &wgpu::Device,
820 module: &ShaderModule,
821 initial_state_gpu: &Buffer,
822 bucket_sizes_gpu: &Buffer,
823 buckets_gpu: &Buffer,
824) -> (BindGroup, ComputePipeline) {
825 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
826 label: Some("compute_f1"),
827 entries: &[
828 BindGroupLayoutEntry {
829 binding: 0,
830 count: None,
831 visibility: ShaderStages::COMPUTE,
832 ty: BindingType::Buffer {
833 has_dynamic_offset: false,
834 min_binding_size: None,
835 ty: BufferBindingType::Uniform,
836 },
837 },
838 BindGroupLayoutEntry {
839 binding: 1,
840 count: None,
841 visibility: ShaderStages::COMPUTE,
842 ty: BindingType::Buffer {
843 has_dynamic_offset: false,
844 min_binding_size: None,
845 ty: BufferBindingType::Storage { read_only: false },
846 },
847 },
848 BindGroupLayoutEntry {
849 binding: 2,
850 count: None,
851 visibility: ShaderStages::COMPUTE,
852 ty: BindingType::Buffer {
853 has_dynamic_offset: false,
854 min_binding_size: None,
855 ty: BufferBindingType::Storage { read_only: false },
856 },
857 },
858 ],
859 });
860
861 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
862 label: Some("compute_f1"),
863 bind_group_layouts: &[Some(&bind_group_layout)],
864 immediate_size: 0,
865 });
866
867 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
868 compilation_options: PipelineCompilationOptions {
869 constants: &[],
870 zero_initialize_workgroup_memory: false,
871 },
872 cache: None,
873 label: Some("compute_f1"),
874 layout: Some(&pipeline_layout),
875 module,
876 entry_point: Some("compute_f1"),
877 });
878
879 let bind_group = device.create_bind_group(&BindGroupDescriptor {
880 label: Some("compute_f1"),
881 layout: &bind_group_layout,
882 entries: &[
883 BindGroupEntry {
884 binding: 0,
885 resource: initial_state_gpu.as_entire_binding(),
886 },
887 BindGroupEntry {
888 binding: 1,
889 resource: bucket_sizes_gpu.as_entire_binding(),
890 },
891 BindGroupEntry {
892 binding: 2,
893 resource: buckets_gpu.as_entire_binding(),
894 },
895 ],
896 });
897
898 (bind_group, compute_pipeline)
899}
900
901fn bind_group_and_pipeline_sort_buckets(
902 device: &wgpu::Device,
903 module: &ShaderModule,
904 bucket_sizes_gpu: &Buffer,
905 buckets_gpu: &Buffer,
906) -> (BindGroup, ComputePipeline) {
907 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
908 label: Some("sort_buckets"),
909 entries: &[
910 BindGroupLayoutEntry {
911 binding: 0,
912 count: None,
913 visibility: ShaderStages::COMPUTE,
914 ty: BindingType::Buffer {
915 has_dynamic_offset: false,
916 min_binding_size: None,
917 ty: BufferBindingType::Storage { read_only: false },
918 },
919 },
920 BindGroupLayoutEntry {
921 binding: 1,
922 count: None,
923 visibility: ShaderStages::COMPUTE,
924 ty: BindingType::Buffer {
925 has_dynamic_offset: false,
926 min_binding_size: None,
927 ty: BufferBindingType::Storage { read_only: false },
928 },
929 },
930 ],
931 });
932
933 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
934 label: Some("sort_buckets"),
935 bind_group_layouts: &[Some(&bind_group_layout)],
936 immediate_size: 0,
937 });
938
939 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
940 compilation_options: PipelineCompilationOptions {
941 constants: &[],
942 zero_initialize_workgroup_memory: false,
943 },
944 cache: None,
945 label: Some("sort_buckets"),
946 layout: Some(&pipeline_layout),
947 module,
948 entry_point: Some("sort_buckets"),
949 });
950
951 let bind_group = device.create_bind_group(&BindGroupDescriptor {
952 label: Some("sort_buckets"),
953 layout: &bind_group_layout,
954 entries: &[
955 BindGroupEntry {
956 binding: 0,
957 resource: bucket_sizes_gpu.as_entire_binding(),
958 },
959 BindGroupEntry {
960 binding: 1,
961 resource: buckets_gpu.as_entire_binding(),
962 },
963 ],
964 });
965
966 (bind_group, compute_pipeline)
967}
968
969fn bind_group_and_pipeline_find_matches_and_compute_f2(
970 device: &wgpu::Device,
971 module: &ShaderModule,
972 parent_buckets_gpu: &Buffer,
973 bucket_sizes_gpu: &Buffer,
974 buckets_gpu: &Buffer,
975 positions_gpu: &Buffer,
976 metadatas_gpu: &Buffer,
977) -> (BindGroup, ComputePipeline) {
978 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
979 label: Some("find_matches_and_compute_f2"),
980 entries: &[
981 BindGroupLayoutEntry {
982 binding: 0,
983 count: None,
984 visibility: ShaderStages::COMPUTE,
985 ty: BindingType::Buffer {
986 has_dynamic_offset: false,
987 min_binding_size: None,
988 ty: BufferBindingType::Storage { read_only: true },
989 },
990 },
991 BindGroupLayoutEntry {
992 binding: 1,
993 count: None,
994 visibility: ShaderStages::COMPUTE,
995 ty: BindingType::Buffer {
996 has_dynamic_offset: false,
997 min_binding_size: None,
998 ty: BufferBindingType::Storage { read_only: false },
999 },
1000 },
1001 BindGroupLayoutEntry {
1002 binding: 2,
1003 count: None,
1004 visibility: ShaderStages::COMPUTE,
1005 ty: BindingType::Buffer {
1006 has_dynamic_offset: false,
1007 min_binding_size: None,
1008 ty: BufferBindingType::Storage { read_only: false },
1009 },
1010 },
1011 BindGroupLayoutEntry {
1012 binding: 3,
1013 count: None,
1014 visibility: ShaderStages::COMPUTE,
1015 ty: BindingType::Buffer {
1016 has_dynamic_offset: false,
1017 min_binding_size: None,
1018 ty: BufferBindingType::Storage { read_only: false },
1019 },
1020 },
1021 BindGroupLayoutEntry {
1022 binding: 4,
1023 count: None,
1024 visibility: ShaderStages::COMPUTE,
1025 ty: BindingType::Buffer {
1026 has_dynamic_offset: false,
1027 min_binding_size: None,
1028 ty: BufferBindingType::Storage { read_only: false },
1029 },
1030 },
1031 ],
1032 });
1033
1034 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1035 label: Some("find_matches_and_compute_f2"),
1036 bind_group_layouts: &[Some(&bind_group_layout)],
1037 immediate_size: 0,
1038 });
1039
1040 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1041 compilation_options: PipelineCompilationOptions {
1042 constants: &[],
1043 zero_initialize_workgroup_memory: true,
1044 },
1045 cache: None,
1046 label: Some("find_matches_and_compute_f2"),
1047 layout: Some(&pipeline_layout),
1048 module,
1049 entry_point: Some("find_matches_and_compute_f2"),
1050 });
1051
1052 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1053 label: Some("find_matches_and_compute_f2"),
1054 layout: &bind_group_layout,
1055 entries: &[
1056 BindGroupEntry {
1057 binding: 0,
1058 resource: parent_buckets_gpu.as_entire_binding(),
1059 },
1060 BindGroupEntry {
1061 binding: 1,
1062 resource: bucket_sizes_gpu.as_entire_binding(),
1063 },
1064 BindGroupEntry {
1065 binding: 2,
1066 resource: buckets_gpu.as_entire_binding(),
1067 },
1068 BindGroupEntry {
1069 binding: 3,
1070 resource: positions_gpu.as_entire_binding(),
1071 },
1072 BindGroupEntry {
1073 binding: 4,
1074 resource: metadatas_gpu.as_entire_binding(),
1075 },
1076 ],
1077 });
1078
1079 (bind_group, compute_pipeline)
1080}
1081
1082#[expect(
1083 clippy::too_many_arguments,
1084 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1085)]
1086fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1087 device: &wgpu::Device,
1088 module: &ShaderModule,
1089 parent_buckets_gpu: &Buffer,
1090 parent_metadatas_gpu: &Buffer,
1091 bucket_sizes_gpu: &Buffer,
1092 buckets_gpu: &Buffer,
1093 positions_gpu: &Buffer,
1094 metadatas_gpu: &Buffer,
1095) -> (BindGroup, ComputePipeline) {
1096 let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1097 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1098 label: Some(&label),
1099 entries: &[
1100 BindGroupLayoutEntry {
1101 binding: 0,
1102 count: None,
1103 visibility: ShaderStages::COMPUTE,
1104 ty: BindingType::Buffer {
1105 has_dynamic_offset: false,
1106 min_binding_size: None,
1107 ty: BufferBindingType::Storage { read_only: true },
1108 },
1109 },
1110 BindGroupLayoutEntry {
1111 binding: 1,
1112 count: None,
1113 visibility: ShaderStages::COMPUTE,
1114 ty: BindingType::Buffer {
1115 has_dynamic_offset: false,
1116 min_binding_size: None,
1117 ty: BufferBindingType::Storage { read_only: true },
1118 },
1119 },
1120 BindGroupLayoutEntry {
1121 binding: 2,
1122 count: None,
1123 visibility: ShaderStages::COMPUTE,
1124 ty: BindingType::Buffer {
1125 has_dynamic_offset: false,
1126 min_binding_size: None,
1127 ty: BufferBindingType::Storage { read_only: false },
1128 },
1129 },
1130 BindGroupLayoutEntry {
1131 binding: 3,
1132 count: None,
1133 visibility: ShaderStages::COMPUTE,
1134 ty: BindingType::Buffer {
1135 has_dynamic_offset: false,
1136 min_binding_size: None,
1137 ty: BufferBindingType::Storage { read_only: false },
1138 },
1139 },
1140 BindGroupLayoutEntry {
1141 binding: 4,
1142 count: None,
1143 visibility: ShaderStages::COMPUTE,
1144 ty: BindingType::Buffer {
1145 has_dynamic_offset: false,
1146 min_binding_size: None,
1147 ty: BufferBindingType::Storage { read_only: false },
1148 },
1149 },
1150 BindGroupLayoutEntry {
1151 binding: 5,
1152 count: None,
1153 visibility: ShaderStages::COMPUTE,
1154 ty: BindingType::Buffer {
1155 has_dynamic_offset: false,
1156 min_binding_size: None,
1157 ty: BufferBindingType::Storage { read_only: false },
1158 },
1159 },
1160 ],
1161 });
1162
1163 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1164 label: Some(&label),
1165 bind_group_layouts: &[Some(&bind_group_layout)],
1166 immediate_size: 0,
1167 });
1168
1169 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1170 compilation_options: PipelineCompilationOptions {
1171 constants: &[],
1172 zero_initialize_workgroup_memory: true,
1173 },
1174 cache: None,
1175 label: Some(&label),
1176 layout: Some(&pipeline_layout),
1177 module,
1178 entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1179 });
1180
1181 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1182 label: Some(&label),
1183 layout: &bind_group_layout,
1184 entries: &[
1185 BindGroupEntry {
1186 binding: 0,
1187 resource: parent_buckets_gpu.as_entire_binding(),
1188 },
1189 BindGroupEntry {
1190 binding: 1,
1191 resource: parent_metadatas_gpu.as_entire_binding(),
1192 },
1193 BindGroupEntry {
1194 binding: 2,
1195 resource: bucket_sizes_gpu.as_entire_binding(),
1196 },
1197 BindGroupEntry {
1198 binding: 3,
1199 resource: buckets_gpu.as_entire_binding(),
1200 },
1201 BindGroupEntry {
1202 binding: 4,
1203 resource: positions_gpu.as_entire_binding(),
1204 },
1205 BindGroupEntry {
1206 binding: 5,
1207 resource: metadatas_gpu.as_entire_binding(),
1208 },
1209 ],
1210 });
1211
1212 (bind_group, compute_pipeline)
1213}
1214
1215fn bind_group_and_pipeline_find_matches_and_compute_f7(
1216 device: &wgpu::Device,
1217 module: &ShaderModule,
1218 parent_buckets_gpu: &Buffer,
1219 parent_metadatas_gpu: &Buffer,
1220 table_6_proof_targets_sizes_gpu: &Buffer,
1221 table_6_proof_targets_gpu: &Buffer,
1222) -> (BindGroup, ComputePipeline) {
1223 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1224 label: Some("find_matches_and_compute_f7"),
1225 entries: &[
1226 BindGroupLayoutEntry {
1227 binding: 0,
1228 count: None,
1229 visibility: ShaderStages::COMPUTE,
1230 ty: BindingType::Buffer {
1231 has_dynamic_offset: false,
1232 min_binding_size: None,
1233 ty: BufferBindingType::Storage { read_only: true },
1234 },
1235 },
1236 BindGroupLayoutEntry {
1237 binding: 1,
1238 count: None,
1239 visibility: ShaderStages::COMPUTE,
1240 ty: BindingType::Buffer {
1241 has_dynamic_offset: false,
1242 min_binding_size: None,
1243 ty: BufferBindingType::Storage { read_only: true },
1244 },
1245 },
1246 BindGroupLayoutEntry {
1247 binding: 2,
1248 count: None,
1249 visibility: ShaderStages::COMPUTE,
1250 ty: BindingType::Buffer {
1251 has_dynamic_offset: false,
1252 min_binding_size: None,
1253 ty: BufferBindingType::Storage { read_only: false },
1254 },
1255 },
1256 BindGroupLayoutEntry {
1257 binding: 3,
1258 count: None,
1259 visibility: ShaderStages::COMPUTE,
1260 ty: BindingType::Buffer {
1261 has_dynamic_offset: false,
1262 min_binding_size: None,
1263 ty: BufferBindingType::Storage { read_only: false },
1264 },
1265 },
1266 ],
1267 });
1268
1269 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1270 label: Some("find_matches_and_compute_f7"),
1271 bind_group_layouts: &[Some(&bind_group_layout)],
1272 immediate_size: 0,
1273 });
1274
1275 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1276 compilation_options: PipelineCompilationOptions {
1277 constants: &[],
1278 zero_initialize_workgroup_memory: true,
1279 },
1280 cache: None,
1281 label: Some("find_matches_and_compute_f7"),
1282 layout: Some(&pipeline_layout),
1283 module,
1284 entry_point: Some("find_matches_and_compute_f7"),
1285 });
1286
1287 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1288 label: Some("find_matches_and_compute_f7"),
1289 layout: &bind_group_layout,
1290 entries: &[
1291 BindGroupEntry {
1292 binding: 0,
1293 resource: parent_buckets_gpu.as_entire_binding(),
1294 },
1295 BindGroupEntry {
1296 binding: 1,
1297 resource: parent_metadatas_gpu.as_entire_binding(),
1298 },
1299 BindGroupEntry {
1300 binding: 2,
1301 resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1302 },
1303 BindGroupEntry {
1304 binding: 3,
1305 resource: table_6_proof_targets_gpu.as_entire_binding(),
1306 },
1307 ],
1308 });
1309
1310 (bind_group, compute_pipeline)
1311}
1312
1313#[expect(
1314 clippy::too_many_arguments,
1315 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1316)]
1317fn bind_group_and_pipeline_find_proofs(
1318 device: &wgpu::Device,
1319 module: &ShaderModule,
1320 table_2_positions_gpu: &Buffer,
1321 table_3_positions_gpu: &Buffer,
1322 table_4_positions_gpu: &Buffer,
1323 table_5_positions_gpu: &Buffer,
1324 table_6_positions_gpu: &Buffer,
1325 bucket_sizes_gpu: &Buffer,
1326 buckets_gpu: &Buffer,
1327 proofs_gpu: &Buffer,
1328) -> (BindGroup, ComputePipeline) {
1329 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1330 label: Some("find_proofs"),
1331 entries: &[
1332 BindGroupLayoutEntry {
1333 binding: 0,
1334 count: None,
1335 visibility: ShaderStages::COMPUTE,
1336 ty: BindingType::Buffer {
1337 has_dynamic_offset: false,
1338 min_binding_size: None,
1339 ty: BufferBindingType::Storage { read_only: true },
1340 },
1341 },
1342 BindGroupLayoutEntry {
1343 binding: 1,
1344 count: None,
1345 visibility: ShaderStages::COMPUTE,
1346 ty: BindingType::Buffer {
1347 has_dynamic_offset: false,
1348 min_binding_size: None,
1349 ty: BufferBindingType::Storage { read_only: true },
1350 },
1351 },
1352 BindGroupLayoutEntry {
1353 binding: 2,
1354 count: None,
1355 visibility: ShaderStages::COMPUTE,
1356 ty: BindingType::Buffer {
1357 has_dynamic_offset: false,
1358 min_binding_size: None,
1359 ty: BufferBindingType::Storage { read_only: true },
1360 },
1361 },
1362 BindGroupLayoutEntry {
1363 binding: 3,
1364 count: None,
1365 visibility: ShaderStages::COMPUTE,
1366 ty: BindingType::Buffer {
1367 has_dynamic_offset: false,
1368 min_binding_size: None,
1369 ty: BufferBindingType::Storage { read_only: true },
1370 },
1371 },
1372 BindGroupLayoutEntry {
1373 binding: 4,
1374 count: None,
1375 visibility: ShaderStages::COMPUTE,
1376 ty: BindingType::Buffer {
1377 has_dynamic_offset: false,
1378 min_binding_size: None,
1379 ty: BufferBindingType::Storage { read_only: true },
1380 },
1381 },
1382 BindGroupLayoutEntry {
1383 binding: 5,
1384 count: None,
1385 visibility: ShaderStages::COMPUTE,
1386 ty: BindingType::Buffer {
1387 has_dynamic_offset: false,
1388 min_binding_size: None,
1389 ty: BufferBindingType::Storage { read_only: false },
1390 },
1391 },
1392 BindGroupLayoutEntry {
1393 binding: 6,
1394 count: None,
1395 visibility: ShaderStages::COMPUTE,
1396 ty: BindingType::Buffer {
1397 has_dynamic_offset: false,
1398 min_binding_size: None,
1399 ty: BufferBindingType::Storage { read_only: true },
1400 },
1401 },
1402 BindGroupLayoutEntry {
1403 binding: 7,
1404 count: None,
1405 visibility: ShaderStages::COMPUTE,
1406 ty: BindingType::Buffer {
1407 has_dynamic_offset: false,
1408 min_binding_size: None,
1409 ty: BufferBindingType::Storage { read_only: false },
1410 },
1411 },
1412 ],
1413 });
1414
1415 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1416 label: Some("find_proofs"),
1417 bind_group_layouts: &[Some(&bind_group_layout)],
1418 immediate_size: 0,
1419 });
1420
1421 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1422 compilation_options: PipelineCompilationOptions {
1423 constants: &[],
1424 zero_initialize_workgroup_memory: false,
1425 },
1426 cache: None,
1427 label: Some("find_proofs"),
1428 layout: Some(&pipeline_layout),
1429 module,
1430 entry_point: Some("find_proofs"),
1431 });
1432
1433 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1434 label: Some("find_proofs"),
1435 layout: &bind_group_layout,
1436 entries: &[
1437 BindGroupEntry {
1438 binding: 0,
1439 resource: table_2_positions_gpu.as_entire_binding(),
1440 },
1441 BindGroupEntry {
1442 binding: 1,
1443 resource: table_3_positions_gpu.as_entire_binding(),
1444 },
1445 BindGroupEntry {
1446 binding: 2,
1447 resource: table_4_positions_gpu.as_entire_binding(),
1448 },
1449 BindGroupEntry {
1450 binding: 3,
1451 resource: table_5_positions_gpu.as_entire_binding(),
1452 },
1453 BindGroupEntry {
1454 binding: 4,
1455 resource: table_6_positions_gpu.as_entire_binding(),
1456 },
1457 BindGroupEntry {
1458 binding: 5,
1459 resource: bucket_sizes_gpu.as_entire_binding(),
1460 },
1461 BindGroupEntry {
1462 binding: 6,
1463 resource: buckets_gpu.as_entire_binding(),
1464 },
1465 BindGroupEntry {
1466 binding: 7,
1467 resource: proofs_gpu.as_entire_binding(),
1468 },
1469 ],
1470 });
1471
1472 (bind_group, compute_pipeline)
1473}