1#[cfg(all(test, not(miri)))]
2mod tests;
3
4use crate::shader::constants::{
5 MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6 REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::num::NonZeroU8;
27use std::simd::Simd;
28use std::sync::Arc as StdArc;
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::{fmt, iter};
31use tracing::{debug, warn};
32use wgpu::{
33 AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34 BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35 BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36 ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37 DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38 PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39 RequestDeviceError, ShaderModule, ShaderStages,
40};
41
/// Errors that can occur while encoding records with [`GpuRecordsEncoder`]
#[derive(Debug, thiserror::Error)]
enum RecordEncodingError {
    /// Number of records does not fit into `u16`
    #[error("Too many records: {0}")]
    TooManyRecords(usize),
    /// A previous proof creation did not finish successfully, so the instance stayed tainted and
    /// refuses to do further work
    #[error("Proof creation failed previously and the device is now considered broken")]
    DeviceBroken,
    /// Mapping a GPU buffer into host memory failed
    #[error("Failed to map buffer: {0}")]
    BufferMapping(#[from] BufferAsyncError),
    /// Waiting for a submission to complete failed
    #[error("Poll error: {0}")]
    DevicePoll(#[from] PollError),
}
58
/// Proofs read back from the GPU, viewed directly in the mapped memory of `proofs_host`.
///
/// Dropping the wrapper unmaps `proofs_host`. `proofs` is a plain reference that can be copied
/// out, so any copy of it must not be used after the wrapper is dropped.
struct ProofsHostWrapper<'a> {
    /// Proofs located in the mapped range of `proofs_host`
    proofs: &'a ProofsHost,
    /// Buffer backing `proofs`, unmapped on drop
    proofs_host: &'a Buffer,
}
63
64impl Drop for ProofsHostWrapper<'_> {
65 fn drop(&mut self) {
66 self.proofs_host.unmap();
67 }
68}
69
/// Compatible GPU adapter together with the logical devices created for it
#[derive(Clone)]
pub struct Device {
    /// Index of the adapter in enumeration order (incompatible adapters also consume an index)
    id: u32,
    /// One logical device, its queue and the compiled shader module per requested queue
    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
    /// Adapter information reported by `wgpu`
    adapter_info: AdapterInfo,
}
77
78impl fmt::Debug for Device {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct("Device")
81 .field("id", &self.id)
82 .field("name", &self.adapter_info.name)
83 .field("device_type", &self.adapter_info.device_type)
84 .field("driver", &self.adapter_info.driver)
85 .field("driver_info", &self.adapter_info.driver_info)
86 .field("backend", &self.adapter_info.backend)
87 .finish_non_exhaustive()
88 }
89}
90
91impl Device {
92 pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
94 where
95 NOQ: Fn(DeviceType) -> NonZeroU8,
96 {
97 let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
98 let instance = Instance::new(InstanceDescriptor {
99 backends,
100 flags: if cfg!(debug_assertions) {
101 InstanceFlags::debugging().with_env()
102 } else {
103 InstanceFlags::from_env_or_default()
104 },
105 memory_budget_thresholds: MemoryBudgetThresholds::default(),
106 backend_options: BackendOptions::from_env_or_default(),
107 display: None,
108 });
109
110 let adapters = instance.enumerate_adapters(backends).await;
111 let number_of_queues = &number_of_queues;
112
113 adapters
114 .into_iter()
115 .zip(0..)
116 .map(|(adapter, id)| async move {
117 let adapter_info = adapter.get_info();
118
119 let (shader, required_features, required_limits) =
120 if let Some((shader, required_features, required_limits)) =
121 select_shader_features_limits(&adapter)
122 {
123 debug!(
124 %id,
125 adapter_info = ?adapter_info,
126 "Compatible adapter found"
127 );
128
129 (shader, required_features, required_limits)
130 } else {
131 debug!(
132 %id,
133 adapter_info = ?adapter_info,
134 "Incompatible adapter found"
135 );
136
137 return None;
138 };
139
140 let devices = iter::repeat_with(|| async {
143 let (device, queue) = adapter
144 .request_device(&DeviceDescriptor {
145 label: None,
146 required_features,
147 required_limits: required_limits.clone(),
148 ..DeviceDescriptor::default()
149 })
150 .await
151 .inspect_err(|error| {
152 warn!(%id, ?adapter_info, %error, "Failed to request the device");
153 })?;
154 let module = device.create_shader_module(shader.clone());
155
156 Ok::<_, RequestDeviceError>((device, queue, module))
157 })
158 .take(usize::from(
159 number_of_queues(adapter_info.device_type).get(),
160 ))
161 .collect::<FuturesOrdered<_>>()
162 .try_collect::<Vec<_>>()
163 .await
164 .ok()?;
165
166 Some(Self {
167 id,
168 devices,
169 adapter_info,
170 })
171 })
172 .collect::<FuturesOrdered<_>>()
173 .filter_map(|device| async move { device })
174 .collect()
175 .await
176 }
177
178 pub fn id(&self) -> u32 {
180 self.id
181 }
182
183 pub fn name(&self) -> &str {
185 &self.adapter_info.name
186 }
187
188 pub fn device_type(&self) -> DeviceType {
190 self.adapter_info.device_type
191 }
192
193 pub fn driver(&self) -> &str {
195 &self.adapter_info.driver
196 }
197
198 pub fn driver_info(&self) -> &str {
200 &self.adapter_info.driver_info
201 }
202
203 pub fn backend(&self) -> Backend {
205 self.adapter_info.backend
206 }
207
208 pub fn instantiate(
209 &self,
210 erasure_coding: ErasureCoding,
211 global_mutex: StdArc<AsyncMutex<()>>,
212 ) -> anyhow::Result<GpuRecordsEncoder> {
213 GpuRecordsEncoder::new(
214 self.id,
215 self.devices.clone(),
216 self.adapter_info.clone(),
217 erasure_coding,
218 global_mutex,
219 )
220 }
221}
222
/// Records encoder that creates proofs of space on a GPU adapter.
///
/// Every logical device is wrapped in its own instance and driven by one thread of
/// `thread_pool`.
pub struct GpuRecordsEncoder {
    /// Adapter index, same as `Device::id()`
    id: u32,
    /// One encoder instance per logical device, indexed by the rayon thread index
    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
    /// Thread pool sized to the number of instances
    thread_pool: ThreadPool,
    /// Adapter information reported by `wgpu`
    adapter_info: AdapterInfo,
    /// Used to extend source record chunks with parity chunks
    erasure_coding: ErasureCoding,
    /// Briefly acquired before each record is encoded (presumably so that encoding can be paused
    /// from elsewhere — TODO confirm)
    global_mutex: StdArc<AsyncMutex<()>>,
}
231
232impl fmt::Debug for GpuRecordsEncoder {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 f.debug_struct("GpuRecordsEncoder")
235 .field("id", &self.id)
236 .field("name", &self.adapter_info.name)
237 .field("device_type", &self.adapter_info.device_type)
238 .field("driver", &self.adapter_info.driver)
239 .field("driver_info", &self.adapter_info.driver_info)
240 .field("backend", &self.adapter_info.backend)
241 .finish_non_exhaustive()
242 }
243}
244
245impl RecordsEncoder for GpuRecordsEncoder {
246 fn encode_records(
248 &mut self,
249 sector_id: &SectorId,
250 records: &mut [Record],
251 abort_early: &AtomicBool,
252 ) -> anyhow::Result<SectorContentsMap> {
253 let mut sector_contents_map = SectorContentsMap::new(
254 u16::try_from(records.len())
255 .map_err(|_error| RecordEncodingError::TooManyRecords(records.len()))?,
256 );
257
258 let maybe_error = self.thread_pool.install(|| {
259 records
260 .par_iter_mut()
261 .zip(sector_contents_map.iter_record_chunks_used_mut())
262 .enumerate()
263 .find_map_any(|(piece_offset, (record, record_chunks_used))| {
264 self.global_mutex.lock_blocking();
266
267 let mut parity_record_chunks = Record::new_boxed();
268
269 self.erasure_coding
272 .extend(record.iter(), parity_record_chunks.iter_mut())
273 .expect("Statically guaranteed valid inputs; qed");
274
275 if abort_early.load(Ordering::Relaxed) {
276 return None;
277 }
278 let seed =
279 sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
280 let thread_index = rayon::current_thread_index().unwrap_or_default();
281 let mut encoder_instance = self.instances[thread_index]
282 .try_lock()
283 .expect("1:1 mapping between threads and devices; qed");
284 let proofs = match encoder_instance.create_proofs(&seed) {
285 Ok(proofs) => proofs,
286 Err(error) => {
287 return Some(error);
288 }
289 };
290 let proofs = proofs.proofs;
291
292 *record_chunks_used = proofs.found_proofs;
293
294 let mut num_found_proofs = 0_usize;
296 for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
297 .array_chunks::<{ u8::BITS as usize }>()
298 .zip(record_chunks_used)
299 {
300 for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
301 if num_found_proofs == Record::NUM_CHUNKS {
302 *found_proofs &=
304 u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
305 break;
306 }
307 if (*found_proofs & (1 << proof_offset)) != 0 {
308 let record_chunk = if s_bucket < Record::NUM_CHUNKS {
309 record[s_bucket]
310 } else {
311 parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
312 };
313
314 record[num_found_proofs] = (Simd::from(record_chunk)
315 ^ Simd::from(*proofs.proofs[s_bucket].hash()))
316 .to_array();
317 num_found_proofs += 1;
318 }
319 }
320 }
321
322 None
323 })
324 });
325
326 if let Some(error) = maybe_error {
327 return Err(error.into());
328 }
329
330 Ok(sector_contents_map)
331 }
332}
333
334impl GpuRecordsEncoder {
335 fn new(
336 id: u32,
337 devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
338 adapter_info: AdapterInfo,
339 erasure_coding: ErasureCoding,
340 global_mutex: StdArc<AsyncMutex<()>>,
341 ) -> anyhow::Result<Self> {
342 let thread_pool = ThreadPoolBuilder::new()
343 .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
344 .num_threads(devices.len())
345 .build()?;
346
347 Ok(Self {
348 id,
349 instances: devices
350 .into_iter()
351 .map(|(device, queue, module)| {
352 Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
353 })
354 .collect(),
355 thread_pool,
356 adapter_info,
357 erasure_coding,
358 global_mutex,
359 })
360 }
361}
362
/// Proof creation state bound to a single logical GPU device.
///
/// Intermediate buffers (buckets, positions, metadatas) are not stored here; they are only
/// referenced through the bind groups below.
struct GpuRecordsEncoderInstance {
    /// Logical device all resources below were created on
    device: wgpu::Device,
    /// Queue used to submit proof creation work
    queue: Queue,
    /// Filled by buffer mapping callbacks and checked after each submission
    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
    /// Set for the duration of `create_proofs()`. It is only cleared on success, so after any
    /// failure the instance refuses to create further proofs
    tainted: bool,
    /// Host-writable staging buffer for the initial ChaCha8 state
    initial_state_host: Buffer,
    /// Uniform buffer with the initial ChaCha8 state, bound to the `compute_f1` stage
    initial_state_gpu: Buffer,
    /// Host-readable copy of `proofs_gpu`
    proofs_host: Buffer,
    /// Proofs buffer bound to the `find_proofs` stage, copied into `proofs_host` afterwards
    proofs_gpu: Buffer,
    // Bind group and pipeline for each compute stage; `sort_buckets_a`/`sort_buckets_b` run the
    // same entry point over buckets A and B respectively
    bind_group_compute_f1: BindGroup,
    compute_pipeline_compute_f1: ComputePipeline,
    bind_group_sort_buckets_a: BindGroup,
    compute_pipeline_sort_buckets_a: ComputePipeline,
    bind_group_sort_buckets_b: BindGroup,
    compute_pipeline_sort_buckets_b: ComputePipeline,
    bind_group_find_matches_and_compute_f2: BindGroup,
    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
    bind_group_find_matches_and_compute_f3: BindGroup,
    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
    bind_group_find_matches_and_compute_f4: BindGroup,
    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
    bind_group_find_matches_and_compute_f5: BindGroup,
    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
    bind_group_find_matches_and_compute_f6: BindGroup,
    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
    bind_group_find_matches_and_compute_f7: BindGroup,
    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
    bind_group_find_proofs: BindGroup,
    compute_pipeline_find_proofs: ComputePipeline,
}
393
394impl GpuRecordsEncoderInstance {
395 fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
396 let initial_state_host = device.create_buffer(&BufferDescriptor {
397 label: Some("initial_state_host"),
398 size: size_of::<ChaCha8Block>() as BufferAddress,
399 usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
400 mapped_at_creation: true,
401 });
402
403 let initial_state_gpu = device.create_buffer(&BufferDescriptor {
404 label: Some("initial_state_gpu"),
405 size: initial_state_host.size(),
406 usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
407 mapped_at_creation: false,
408 });
409
410 let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
411 let table_6_proof_targets_sizes_gpu_buffer_size =
412 size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
413 let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
417 label: Some("bucket_sizes_gpu"),
418 size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
419 usage: BufferUsages::STORAGE,
420 mapped_at_creation: false,
421 });
422 let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();
424
425 let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
426 label: Some("buckets_a_gpu"),
427 size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
428 usage: BufferUsages::STORAGE,
429 mapped_at_creation: false,
430 });
431
432 let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
433 label: Some("buckets_b_gpu"),
434 size: buckets_a_gpu.size(),
435 usage: BufferUsages::STORAGE,
436 mapped_at_creation: false,
437 });
438
439 let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
440 label: Some("positions_f2_gpu"),
441 size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
442 as BufferAddress,
443 usage: BufferUsages::STORAGE,
444 mapped_at_creation: false,
445 });
446
447 let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
448 label: Some("positions_f3_gpu"),
449 size: positions_f2_gpu.size(),
450 usage: BufferUsages::STORAGE,
451 mapped_at_creation: false,
452 });
453
454 let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
455 label: Some("positions_f4_gpu"),
456 size: positions_f2_gpu.size(),
457 usage: BufferUsages::STORAGE,
458 mapped_at_creation: false,
459 });
460
461 let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
462 label: Some("positions_f5_gpu"),
463 size: positions_f2_gpu.size(),
464 usage: BufferUsages::STORAGE,
465 mapped_at_creation: false,
466 });
467
468 let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
469 label: Some("positions_f6_gpu"),
470 size: positions_f2_gpu.size(),
471 usage: BufferUsages::STORAGE,
472 mapped_at_creation: false,
473 });
474
475 let metadatas_gpu_buffer_size =
476 size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
477 let table_6_proof_targets_gpu_buffer_size = size_of::<
478 [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
479 >() as BufferAddress;
480 let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
481 label: Some("metadatas_a_gpu"),
482 size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
483 usage: BufferUsages::STORAGE,
484 mapped_at_creation: false,
485 });
486 let table_6_proof_targets_gpu = metadatas_a_gpu.clone();
488
489 let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
490 label: Some("metadatas_b_gpu"),
491 size: metadatas_gpu_buffer_size,
492 usage: BufferUsages::STORAGE,
493 mapped_at_creation: false,
494 });
495
496 let proofs_host = device.create_buffer(&BufferDescriptor {
497 label: Some("proofs_host"),
498 size: size_of::<ProofsHost>() as BufferAddress,
499 usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
500 mapped_at_creation: false,
501 });
502
503 let proofs_gpu = device.create_buffer(&BufferDescriptor {
504 label: Some("proofs_gpu"),
505 size: proofs_host.size(),
506 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
507 mapped_at_creation: false,
508 });
509
510 let (bind_group_compute_f1, compute_pipeline_compute_f1) =
511 bind_group_and_pipeline_compute_f1(
512 &device,
513 &module,
514 &initial_state_gpu,
515 &bucket_sizes_gpu,
516 &buckets_a_gpu,
517 );
518 let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
519 bind_group_and_pipeline_sort_buckets(
520 &device,
521 &module,
522 &bucket_sizes_gpu,
523 &buckets_a_gpu,
524 );
525
526 let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
527 bind_group_and_pipeline_sort_buckets(
528 &device,
529 &module,
530 &bucket_sizes_gpu,
531 &buckets_b_gpu,
532 );
533
534 let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
535 bind_group_and_pipeline_find_matches_and_compute_f2(
536 &device,
537 &module,
538 &buckets_a_gpu,
539 &bucket_sizes_gpu,
540 &buckets_b_gpu,
541 &positions_f2_gpu,
542 &metadatas_b_gpu,
543 );
544
545 let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
546 bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
547 &device,
548 &module,
549 &buckets_b_gpu,
550 &metadatas_b_gpu,
551 &bucket_sizes_gpu,
552 &buckets_a_gpu,
553 &positions_f3_gpu,
554 &metadatas_a_gpu,
555 );
556
557 let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
558 bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
559 &device,
560 &module,
561 &buckets_a_gpu,
562 &metadatas_a_gpu,
563 &bucket_sizes_gpu,
564 &buckets_b_gpu,
565 &positions_f4_gpu,
566 &metadatas_b_gpu,
567 );
568
569 let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
570 bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
571 &device,
572 &module,
573 &buckets_b_gpu,
574 &metadatas_b_gpu,
575 &bucket_sizes_gpu,
576 &buckets_a_gpu,
577 &positions_f5_gpu,
578 &metadatas_a_gpu,
579 );
580
581 let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
582 bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
583 &device,
584 &module,
585 &buckets_a_gpu,
586 &metadatas_a_gpu,
587 &bucket_sizes_gpu,
588 &buckets_b_gpu,
589 &positions_f6_gpu,
590 &metadatas_b_gpu,
591 );
592
593 let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
594 bind_group_and_pipeline_find_matches_and_compute_f7(
595 &device,
596 &module,
597 &buckets_b_gpu,
598 &metadatas_b_gpu,
599 &table_6_proof_targets_sizes_gpu,
600 &table_6_proof_targets_gpu,
601 );
602
603 let (bind_group_find_proofs, compute_pipeline_find_proofs) =
604 bind_group_and_pipeline_find_proofs(
605 &device,
606 &module,
607 &positions_f2_gpu,
608 &positions_f3_gpu,
609 &positions_f4_gpu,
610 &positions_f5_gpu,
611 &positions_f6_gpu,
612 &table_6_proof_targets_sizes_gpu,
613 &table_6_proof_targets_gpu,
614 &proofs_gpu,
615 );
616
617 Self {
618 device,
619 queue,
620 mapping_error: Arc::new(Mutex::new(None)),
621 tainted: false,
622 initial_state_host,
623 initial_state_gpu,
624 proofs_host,
625 proofs_gpu,
626 bind_group_compute_f1,
627 compute_pipeline_compute_f1,
628 bind_group_sort_buckets_a,
629 compute_pipeline_sort_buckets_a,
630 bind_group_sort_buckets_b,
631 compute_pipeline_sort_buckets_b,
632 bind_group_find_matches_and_compute_f2,
633 compute_pipeline_find_matches_and_compute_f2,
634 bind_group_find_matches_and_compute_f3,
635 compute_pipeline_find_matches_and_compute_f3,
636 bind_group_find_matches_and_compute_f4,
637 compute_pipeline_find_matches_and_compute_f4,
638 bind_group_find_matches_and_compute_f5,
639 compute_pipeline_find_matches_and_compute_f5,
640 bind_group_find_matches_and_compute_f6,
641 compute_pipeline_find_matches_and_compute_f6,
642 bind_group_find_matches_and_compute_f7,
643 compute_pipeline_find_matches_and_compute_f7,
644 bind_group_find_proofs,
645 compute_pipeline_find_proofs,
646 }
647 }
648
649 fn create_proofs(
650 &mut self,
651 seed: &PosSeed,
652 ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
653 if self.tainted {
654 return Err(RecordEncodingError::DeviceBroken);
655 }
656 self.tainted = true;
657
658 let mut encoder = self
659 .device
660 .create_command_encoder(&CommandEncoderDescriptor {
661 label: Some("create_proofs"),
662 });
663
664 self.initial_state_host
666 .get_mapped_range_mut(..)
667 .copy_from_slice(&block_to_bytes(
668 &ChaCha8State::init(seed, &[0; _]).to_repr(),
669 ));
670 self.initial_state_host.unmap();
671
672 encoder.copy_buffer_to_buffer(
673 &self.initial_state_host,
674 0,
675 &self.initial_state_gpu,
676 0,
677 self.initial_state_host.size(),
678 );
679
680 {
681 let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
682 label: Some("create_proofs"),
683 timestamp_writes: None,
684 });
685
686 cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
687 cpass.set_pipeline(&self.compute_pipeline_compute_f1);
688 cpass.dispatch_workgroups(
689 MAX_TABLE_SIZE
690 .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
691 1,
692 1,
693 );
694
695 cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
696 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
697 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
698
699 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
700 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
701 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
702
703 cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
704 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
705 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
706
707 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
708 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
709 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
710
711 cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
712 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
713 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
714
715 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
716 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
717 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
718
719 cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
720 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
721 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
722
723 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
724 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
725 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
726
727 cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
728 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
729 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
730
731 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
732 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
733 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
734
735 cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
736 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
737 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
738
739 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
740 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
741 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
742
743 cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
744 cpass.set_pipeline(&self.compute_pipeline_find_proofs);
745 cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
746 }
747
748 encoder.copy_buffer_to_buffer(
749 &self.proofs_gpu,
750 0,
751 &self.proofs_host,
752 0,
753 self.proofs_host.size(),
754 );
755
756 encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
758 let mapping_error = Arc::clone(&self.mapping_error);
759
760 move |r| {
761 if let Err(error) = r {
762 mapping_error.lock().replace(error);
763 }
764 }
765 });
766 encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
767 let mapping_error = Arc::clone(&self.mapping_error);
768
769 move |r| {
770 if let Err(error) = r {
771 mapping_error.lock().replace(error);
772 }
773 }
774 });
775
776 let submission_index = self.queue.submit([encoder.finish()]);
777
778 self.device.poll(PollType::Wait {
779 submission_index: Some(submission_index),
780 timeout: None,
781 })?;
782
783 if let Some(error) = self.mapping_error.lock().take() {
784 return Err(RecordEncodingError::BufferMapping(error));
785 }
786
787 let proofs = {
788 let proofs_host_ptr = self
789 .proofs_host
790 .get_mapped_range(..)
791 .as_ptr()
792 .cast::<ProofsHost>();
793 unsafe { &*proofs_host_ptr }
795 };
796
797 self.tainted = false;
798
799 Ok(ProofsHostWrapper {
800 proofs,
801 proofs_host: &self.proofs_host,
802 })
803 }
804}
805
806fn bind_group_and_pipeline_compute_f1(
807 device: &wgpu::Device,
808 module: &ShaderModule,
809 initial_state_gpu: &Buffer,
810 bucket_sizes_gpu: &Buffer,
811 buckets_gpu: &Buffer,
812) -> (BindGroup, ComputePipeline) {
813 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
814 label: Some("compute_f1"),
815 entries: &[
816 BindGroupLayoutEntry {
817 binding: 0,
818 count: None,
819 visibility: ShaderStages::COMPUTE,
820 ty: BindingType::Buffer {
821 has_dynamic_offset: false,
822 min_binding_size: None,
823 ty: BufferBindingType::Uniform,
824 },
825 },
826 BindGroupLayoutEntry {
827 binding: 1,
828 count: None,
829 visibility: ShaderStages::COMPUTE,
830 ty: BindingType::Buffer {
831 has_dynamic_offset: false,
832 min_binding_size: None,
833 ty: BufferBindingType::Storage { read_only: false },
834 },
835 },
836 BindGroupLayoutEntry {
837 binding: 2,
838 count: None,
839 visibility: ShaderStages::COMPUTE,
840 ty: BindingType::Buffer {
841 has_dynamic_offset: false,
842 min_binding_size: None,
843 ty: BufferBindingType::Storage { read_only: false },
844 },
845 },
846 ],
847 });
848
849 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
850 label: Some("compute_f1"),
851 bind_group_layouts: &[Some(&bind_group_layout)],
852 immediate_size: 0,
853 });
854
855 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
856 compilation_options: PipelineCompilationOptions {
857 constants: &[],
858 zero_initialize_workgroup_memory: false,
859 },
860 cache: None,
861 label: Some("compute_f1"),
862 layout: Some(&pipeline_layout),
863 module,
864 entry_point: Some("compute_f1"),
865 });
866
867 let bind_group = device.create_bind_group(&BindGroupDescriptor {
868 label: Some("compute_f1"),
869 layout: &bind_group_layout,
870 entries: &[
871 BindGroupEntry {
872 binding: 0,
873 resource: initial_state_gpu.as_entire_binding(),
874 },
875 BindGroupEntry {
876 binding: 1,
877 resource: bucket_sizes_gpu.as_entire_binding(),
878 },
879 BindGroupEntry {
880 binding: 2,
881 resource: buckets_gpu.as_entire_binding(),
882 },
883 ],
884 });
885
886 (bind_group, compute_pipeline)
887}
888
889fn bind_group_and_pipeline_sort_buckets(
890 device: &wgpu::Device,
891 module: &ShaderModule,
892 bucket_sizes_gpu: &Buffer,
893 buckets_gpu: &Buffer,
894) -> (BindGroup, ComputePipeline) {
895 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
896 label: Some("sort_buckets"),
897 entries: &[
898 BindGroupLayoutEntry {
899 binding: 0,
900 count: None,
901 visibility: ShaderStages::COMPUTE,
902 ty: BindingType::Buffer {
903 has_dynamic_offset: false,
904 min_binding_size: None,
905 ty: BufferBindingType::Storage { read_only: false },
906 },
907 },
908 BindGroupLayoutEntry {
909 binding: 1,
910 count: None,
911 visibility: ShaderStages::COMPUTE,
912 ty: BindingType::Buffer {
913 has_dynamic_offset: false,
914 min_binding_size: None,
915 ty: BufferBindingType::Storage { read_only: false },
916 },
917 },
918 ],
919 });
920
921 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
922 label: Some("sort_buckets"),
923 bind_group_layouts: &[Some(&bind_group_layout)],
924 immediate_size: 0,
925 });
926
927 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
928 compilation_options: PipelineCompilationOptions {
929 constants: &[],
930 zero_initialize_workgroup_memory: false,
931 },
932 cache: None,
933 label: Some("sort_buckets"),
934 layout: Some(&pipeline_layout),
935 module,
936 entry_point: Some("sort_buckets"),
937 });
938
939 let bind_group = device.create_bind_group(&BindGroupDescriptor {
940 label: Some("sort_buckets"),
941 layout: &bind_group_layout,
942 entries: &[
943 BindGroupEntry {
944 binding: 0,
945 resource: bucket_sizes_gpu.as_entire_binding(),
946 },
947 BindGroupEntry {
948 binding: 1,
949 resource: buckets_gpu.as_entire_binding(),
950 },
951 ],
952 });
953
954 (bind_group, compute_pipeline)
955}
956
957fn bind_group_and_pipeline_find_matches_and_compute_f2(
958 device: &wgpu::Device,
959 module: &ShaderModule,
960 parent_buckets_gpu: &Buffer,
961 bucket_sizes_gpu: &Buffer,
962 buckets_gpu: &Buffer,
963 positions_gpu: &Buffer,
964 metadatas_gpu: &Buffer,
965) -> (BindGroup, ComputePipeline) {
966 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
967 label: Some("find_matches_and_compute_f2"),
968 entries: &[
969 BindGroupLayoutEntry {
970 binding: 0,
971 count: None,
972 visibility: ShaderStages::COMPUTE,
973 ty: BindingType::Buffer {
974 has_dynamic_offset: false,
975 min_binding_size: None,
976 ty: BufferBindingType::Storage { read_only: true },
977 },
978 },
979 BindGroupLayoutEntry {
980 binding: 1,
981 count: None,
982 visibility: ShaderStages::COMPUTE,
983 ty: BindingType::Buffer {
984 has_dynamic_offset: false,
985 min_binding_size: None,
986 ty: BufferBindingType::Storage { read_only: false },
987 },
988 },
989 BindGroupLayoutEntry {
990 binding: 2,
991 count: None,
992 visibility: ShaderStages::COMPUTE,
993 ty: BindingType::Buffer {
994 has_dynamic_offset: false,
995 min_binding_size: None,
996 ty: BufferBindingType::Storage { read_only: false },
997 },
998 },
999 BindGroupLayoutEntry {
1000 binding: 3,
1001 count: None,
1002 visibility: ShaderStages::COMPUTE,
1003 ty: BindingType::Buffer {
1004 has_dynamic_offset: false,
1005 min_binding_size: None,
1006 ty: BufferBindingType::Storage { read_only: false },
1007 },
1008 },
1009 BindGroupLayoutEntry {
1010 binding: 4,
1011 count: None,
1012 visibility: ShaderStages::COMPUTE,
1013 ty: BindingType::Buffer {
1014 has_dynamic_offset: false,
1015 min_binding_size: None,
1016 ty: BufferBindingType::Storage { read_only: false },
1017 },
1018 },
1019 ],
1020 });
1021
1022 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1023 label: Some("find_matches_and_compute_f2"),
1024 bind_group_layouts: &[Some(&bind_group_layout)],
1025 immediate_size: 0,
1026 });
1027
1028 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1029 compilation_options: PipelineCompilationOptions {
1030 constants: &[],
1031 zero_initialize_workgroup_memory: true,
1032 },
1033 cache: None,
1034 label: Some("find_matches_and_compute_f2"),
1035 layout: Some(&pipeline_layout),
1036 module,
1037 entry_point: Some("find_matches_and_compute_f2"),
1038 });
1039
1040 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1041 label: Some("find_matches_and_compute_f2"),
1042 layout: &bind_group_layout,
1043 entries: &[
1044 BindGroupEntry {
1045 binding: 0,
1046 resource: parent_buckets_gpu.as_entire_binding(),
1047 },
1048 BindGroupEntry {
1049 binding: 1,
1050 resource: bucket_sizes_gpu.as_entire_binding(),
1051 },
1052 BindGroupEntry {
1053 binding: 2,
1054 resource: buckets_gpu.as_entire_binding(),
1055 },
1056 BindGroupEntry {
1057 binding: 3,
1058 resource: positions_gpu.as_entire_binding(),
1059 },
1060 BindGroupEntry {
1061 binding: 4,
1062 resource: metadatas_gpu.as_entire_binding(),
1063 },
1064 ],
1065 });
1066
1067 (bind_group, compute_pipeline)
1068}
1069
1070#[expect(
1071 clippy::too_many_arguments,
1072 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1073)]
1074fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1075 device: &wgpu::Device,
1076 module: &ShaderModule,
1077 parent_buckets_gpu: &Buffer,
1078 parent_metadatas_gpu: &Buffer,
1079 bucket_sizes_gpu: &Buffer,
1080 buckets_gpu: &Buffer,
1081 positions_gpu: &Buffer,
1082 metadatas_gpu: &Buffer,
1083) -> (BindGroup, ComputePipeline) {
1084 let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1085 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1086 label: Some(&label),
1087 entries: &[
1088 BindGroupLayoutEntry {
1089 binding: 0,
1090 count: None,
1091 visibility: ShaderStages::COMPUTE,
1092 ty: BindingType::Buffer {
1093 has_dynamic_offset: false,
1094 min_binding_size: None,
1095 ty: BufferBindingType::Storage { read_only: true },
1096 },
1097 },
1098 BindGroupLayoutEntry {
1099 binding: 1,
1100 count: None,
1101 visibility: ShaderStages::COMPUTE,
1102 ty: BindingType::Buffer {
1103 has_dynamic_offset: false,
1104 min_binding_size: None,
1105 ty: BufferBindingType::Storage { read_only: true },
1106 },
1107 },
1108 BindGroupLayoutEntry {
1109 binding: 2,
1110 count: None,
1111 visibility: ShaderStages::COMPUTE,
1112 ty: BindingType::Buffer {
1113 has_dynamic_offset: false,
1114 min_binding_size: None,
1115 ty: BufferBindingType::Storage { read_only: false },
1116 },
1117 },
1118 BindGroupLayoutEntry {
1119 binding: 3,
1120 count: None,
1121 visibility: ShaderStages::COMPUTE,
1122 ty: BindingType::Buffer {
1123 has_dynamic_offset: false,
1124 min_binding_size: None,
1125 ty: BufferBindingType::Storage { read_only: false },
1126 },
1127 },
1128 BindGroupLayoutEntry {
1129 binding: 4,
1130 count: None,
1131 visibility: ShaderStages::COMPUTE,
1132 ty: BindingType::Buffer {
1133 has_dynamic_offset: false,
1134 min_binding_size: None,
1135 ty: BufferBindingType::Storage { read_only: false },
1136 },
1137 },
1138 BindGroupLayoutEntry {
1139 binding: 5,
1140 count: None,
1141 visibility: ShaderStages::COMPUTE,
1142 ty: BindingType::Buffer {
1143 has_dynamic_offset: false,
1144 min_binding_size: None,
1145 ty: BufferBindingType::Storage { read_only: false },
1146 },
1147 },
1148 ],
1149 });
1150
1151 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1152 label: Some(&label),
1153 bind_group_layouts: &[Some(&bind_group_layout)],
1154 immediate_size: 0,
1155 });
1156
1157 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1158 compilation_options: PipelineCompilationOptions {
1159 constants: &[],
1160 zero_initialize_workgroup_memory: true,
1161 },
1162 cache: None,
1163 label: Some(&label),
1164 layout: Some(&pipeline_layout),
1165 module,
1166 entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1167 });
1168
1169 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1170 label: Some(&label),
1171 layout: &bind_group_layout,
1172 entries: &[
1173 BindGroupEntry {
1174 binding: 0,
1175 resource: parent_buckets_gpu.as_entire_binding(),
1176 },
1177 BindGroupEntry {
1178 binding: 1,
1179 resource: parent_metadatas_gpu.as_entire_binding(),
1180 },
1181 BindGroupEntry {
1182 binding: 2,
1183 resource: bucket_sizes_gpu.as_entire_binding(),
1184 },
1185 BindGroupEntry {
1186 binding: 3,
1187 resource: buckets_gpu.as_entire_binding(),
1188 },
1189 BindGroupEntry {
1190 binding: 4,
1191 resource: positions_gpu.as_entire_binding(),
1192 },
1193 BindGroupEntry {
1194 binding: 5,
1195 resource: metadatas_gpu.as_entire_binding(),
1196 },
1197 ],
1198 });
1199
1200 (bind_group, compute_pipeline)
1201}
1202
1203fn bind_group_and_pipeline_find_matches_and_compute_f7(
1204 device: &wgpu::Device,
1205 module: &ShaderModule,
1206 parent_buckets_gpu: &Buffer,
1207 parent_metadatas_gpu: &Buffer,
1208 table_6_proof_targets_sizes_gpu: &Buffer,
1209 table_6_proof_targets_gpu: &Buffer,
1210) -> (BindGroup, ComputePipeline) {
1211 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1212 label: Some("find_matches_and_compute_f7"),
1213 entries: &[
1214 BindGroupLayoutEntry {
1215 binding: 0,
1216 count: None,
1217 visibility: ShaderStages::COMPUTE,
1218 ty: BindingType::Buffer {
1219 has_dynamic_offset: false,
1220 min_binding_size: None,
1221 ty: BufferBindingType::Storage { read_only: true },
1222 },
1223 },
1224 BindGroupLayoutEntry {
1225 binding: 1,
1226 count: None,
1227 visibility: ShaderStages::COMPUTE,
1228 ty: BindingType::Buffer {
1229 has_dynamic_offset: false,
1230 min_binding_size: None,
1231 ty: BufferBindingType::Storage { read_only: true },
1232 },
1233 },
1234 BindGroupLayoutEntry {
1235 binding: 2,
1236 count: None,
1237 visibility: ShaderStages::COMPUTE,
1238 ty: BindingType::Buffer {
1239 has_dynamic_offset: false,
1240 min_binding_size: None,
1241 ty: BufferBindingType::Storage { read_only: false },
1242 },
1243 },
1244 BindGroupLayoutEntry {
1245 binding: 3,
1246 count: None,
1247 visibility: ShaderStages::COMPUTE,
1248 ty: BindingType::Buffer {
1249 has_dynamic_offset: false,
1250 min_binding_size: None,
1251 ty: BufferBindingType::Storage { read_only: false },
1252 },
1253 },
1254 ],
1255 });
1256
1257 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1258 label: Some("find_matches_and_compute_f7"),
1259 bind_group_layouts: &[Some(&bind_group_layout)],
1260 immediate_size: 0,
1261 });
1262
1263 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1264 compilation_options: PipelineCompilationOptions {
1265 constants: &[],
1266 zero_initialize_workgroup_memory: true,
1267 },
1268 cache: None,
1269 label: Some("find_matches_and_compute_f7"),
1270 layout: Some(&pipeline_layout),
1271 module,
1272 entry_point: Some("find_matches_and_compute_f7"),
1273 });
1274
1275 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1276 label: Some("find_matches_and_compute_f7"),
1277 layout: &bind_group_layout,
1278 entries: &[
1279 BindGroupEntry {
1280 binding: 0,
1281 resource: parent_buckets_gpu.as_entire_binding(),
1282 },
1283 BindGroupEntry {
1284 binding: 1,
1285 resource: parent_metadatas_gpu.as_entire_binding(),
1286 },
1287 BindGroupEntry {
1288 binding: 2,
1289 resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1290 },
1291 BindGroupEntry {
1292 binding: 3,
1293 resource: table_6_proof_targets_gpu.as_entire_binding(),
1294 },
1295 ],
1296 });
1297
1298 (bind_group, compute_pipeline)
1299}
1300
1301#[expect(
1302 clippy::too_many_arguments,
1303 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1304)]
1305fn bind_group_and_pipeline_find_proofs(
1306 device: &wgpu::Device,
1307 module: &ShaderModule,
1308 table_2_positions_gpu: &Buffer,
1309 table_3_positions_gpu: &Buffer,
1310 table_4_positions_gpu: &Buffer,
1311 table_5_positions_gpu: &Buffer,
1312 table_6_positions_gpu: &Buffer,
1313 bucket_sizes_gpu: &Buffer,
1314 buckets_gpu: &Buffer,
1315 proofs_gpu: &Buffer,
1316) -> (BindGroup, ComputePipeline) {
1317 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1318 label: Some("find_proofs"),
1319 entries: &[
1320 BindGroupLayoutEntry {
1321 binding: 0,
1322 count: None,
1323 visibility: ShaderStages::COMPUTE,
1324 ty: BindingType::Buffer {
1325 has_dynamic_offset: false,
1326 min_binding_size: None,
1327 ty: BufferBindingType::Storage { read_only: true },
1328 },
1329 },
1330 BindGroupLayoutEntry {
1331 binding: 1,
1332 count: None,
1333 visibility: ShaderStages::COMPUTE,
1334 ty: BindingType::Buffer {
1335 has_dynamic_offset: false,
1336 min_binding_size: None,
1337 ty: BufferBindingType::Storage { read_only: true },
1338 },
1339 },
1340 BindGroupLayoutEntry {
1341 binding: 2,
1342 count: None,
1343 visibility: ShaderStages::COMPUTE,
1344 ty: BindingType::Buffer {
1345 has_dynamic_offset: false,
1346 min_binding_size: None,
1347 ty: BufferBindingType::Storage { read_only: true },
1348 },
1349 },
1350 BindGroupLayoutEntry {
1351 binding: 3,
1352 count: None,
1353 visibility: ShaderStages::COMPUTE,
1354 ty: BindingType::Buffer {
1355 has_dynamic_offset: false,
1356 min_binding_size: None,
1357 ty: BufferBindingType::Storage { read_only: true },
1358 },
1359 },
1360 BindGroupLayoutEntry {
1361 binding: 4,
1362 count: None,
1363 visibility: ShaderStages::COMPUTE,
1364 ty: BindingType::Buffer {
1365 has_dynamic_offset: false,
1366 min_binding_size: None,
1367 ty: BufferBindingType::Storage { read_only: true },
1368 },
1369 },
1370 BindGroupLayoutEntry {
1371 binding: 5,
1372 count: None,
1373 visibility: ShaderStages::COMPUTE,
1374 ty: BindingType::Buffer {
1375 has_dynamic_offset: false,
1376 min_binding_size: None,
1377 ty: BufferBindingType::Storage { read_only: false },
1378 },
1379 },
1380 BindGroupLayoutEntry {
1381 binding: 6,
1382 count: None,
1383 visibility: ShaderStages::COMPUTE,
1384 ty: BindingType::Buffer {
1385 has_dynamic_offset: false,
1386 min_binding_size: None,
1387 ty: BufferBindingType::Storage { read_only: true },
1388 },
1389 },
1390 BindGroupLayoutEntry {
1391 binding: 7,
1392 count: None,
1393 visibility: ShaderStages::COMPUTE,
1394 ty: BindingType::Buffer {
1395 has_dynamic_offset: false,
1396 min_binding_size: None,
1397 ty: BufferBindingType::Storage { read_only: false },
1398 },
1399 },
1400 ],
1401 });
1402
1403 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1404 label: Some("find_proofs"),
1405 bind_group_layouts: &[Some(&bind_group_layout)],
1406 immediate_size: 0,
1407 });
1408
1409 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1410 compilation_options: PipelineCompilationOptions {
1411 constants: &[],
1412 zero_initialize_workgroup_memory: false,
1413 },
1414 cache: None,
1415 label: Some("find_proofs"),
1416 layout: Some(&pipeline_layout),
1417 module,
1418 entry_point: Some("find_proofs"),
1419 });
1420
1421 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1422 label: Some("find_proofs"),
1423 layout: &bind_group_layout,
1424 entries: &[
1425 BindGroupEntry {
1426 binding: 0,
1427 resource: table_2_positions_gpu.as_entire_binding(),
1428 },
1429 BindGroupEntry {
1430 binding: 1,
1431 resource: table_3_positions_gpu.as_entire_binding(),
1432 },
1433 BindGroupEntry {
1434 binding: 2,
1435 resource: table_4_positions_gpu.as_entire_binding(),
1436 },
1437 BindGroupEntry {
1438 binding: 3,
1439 resource: table_5_positions_gpu.as_entire_binding(),
1440 },
1441 BindGroupEntry {
1442 binding: 4,
1443 resource: table_6_positions_gpu.as_entire_binding(),
1444 },
1445 BindGroupEntry {
1446 binding: 5,
1447 resource: bucket_sizes_gpu.as_entire_binding(),
1448 },
1449 BindGroupEntry {
1450 binding: 6,
1451 resource: buckets_gpu.as_entire_binding(),
1452 },
1453 BindGroupEntry {
1454 binding: 7,
1455 resource: proofs_gpu.as_entire_binding(),
1456 },
1457 ],
1458 });
1459
1460 (bind_group, compute_pipeline)
1461}