1pub mod metrics;
4
5use crate::plotter::cpu::metrics::CpuPlotterMetrics;
6use crate::plotter::{Plotter, SectorPlottingProgress};
7use crate::thread_pool_manager::PlottingThreadPoolManager;
8use crate::utils::AsyncJoinOnDrop;
9use ab_core_primitives::ed25519::Ed25519PublicKey;
10use ab_core_primitives::sectors::SectorIndex;
11use ab_core_primitives::solutions::ShardCommitmentHash;
12use ab_data_retrieval::piece_getter::PieceGetter;
13use ab_erasure_coding::ErasureCoding;
14use ab_farmer_components::FarmerProtocolInfo;
15use ab_farmer_components::plotting::{
16 CpuRecordsEncoder, DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector,
17 encode_sector, write_sector,
18};
19use ab_proof_of_space::Table;
20use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
21use async_trait::async_trait;
22use bytes::Bytes;
23use event_listener_primitives::{Bag, HandlerId};
24use futures::channel::mpsc;
25use futures::stream::FuturesUnordered;
26use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
27use prometheus_client::registry::Registry;
28use std::any::type_name;
29use std::error::Error;
30use std::fmt;
31use std::future::pending;
32use std::marker::PhantomData;
33use std::num::NonZeroUsize;
34use std::pin::pin;
35use std::sync::Arc;
36use std::sync::atomic::{AtomicBool, Ordering};
37use std::task::Poll;
38use std::time::Instant;
39use tokio::task::yield_now;
40use tracing::{Instrument, warn};
41
42pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
44type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
45
46#[derive(Default, Debug)]
47struct Handlers {
48 plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
49}
50
51pub struct CpuPlotter<PG, PosTable> {
53 piece_getter: PG,
54 downloading_semaphore: Arc<Semaphore>,
55 plotting_thread_pool_manager: PlottingThreadPoolManager,
56 record_encoding_concurrency: NonZeroUsize,
57 global_mutex: Arc<AsyncMutex<()>>,
58 erasure_coding: ErasureCoding,
59 handlers: Arc<Handlers>,
60 tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
61 _background_tasks: AsyncJoinOnDrop<()>,
62 abort_early: Arc<AtomicBool>,
63 metrics: Option<Arc<CpuPlotterMetrics>>,
64 _phantom: PhantomData<PosTable>,
65}
66
67impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
68 #[inline]
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_struct("CpuPlotter").finish_non_exhaustive()
71 }
72}
73
74impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
75 #[inline]
76 fn drop(&mut self) {
77 self.abort_early.store(true, Ordering::Release);
78 self.tasks_sender.close_channel();
79 }
80}
81
82#[async_trait]
83impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
84where
85 PG: PieceGetter + Clone + Send + Sync + 'static,
86 PosTable: Table,
87{
88 async fn has_free_capacity(&self) -> Result<bool, String> {
89 Ok(self.downloading_semaphore.try_acquire().is_some())
90 }
91
92 async fn plot_sector(
93 &self,
94 public_key: Ed25519PublicKey,
95 shard_commitments_root: ShardCommitmentHash,
96 sector_index: SectorIndex,
97 farmer_protocol_info: FarmerProtocolInfo,
98 pieces_in_sector: u16,
99 replotting: bool,
100 progress_sender: mpsc::Sender<SectorPlottingProgress>,
101 ) {
102 let start = Instant::now();
103
104 let downloading_permit = self.downloading_semaphore.acquire_arc().await;
107
108 self.plot_sector_internal(
109 start,
110 downloading_permit,
111 public_key,
112 shard_commitments_root,
113 sector_index,
114 farmer_protocol_info,
115 pieces_in_sector,
116 replotting,
117 progress_sender,
118 )
119 .await
120 }
121
122 async fn try_plot_sector(
123 &self,
124 public_key: Ed25519PublicKey,
125 shard_commitments_root: ShardCommitmentHash,
126 sector_index: SectorIndex,
127 farmer_protocol_info: FarmerProtocolInfo,
128 pieces_in_sector: u16,
129 replotting: bool,
130 progress_sender: mpsc::Sender<SectorPlottingProgress>,
131 ) -> bool {
132 let start = Instant::now();
133
134 let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
135 return false;
136 };
137
138 self.plot_sector_internal(
139 start,
140 downloading_permit,
141 public_key,
142 shard_commitments_root,
143 sector_index,
144 farmer_protocol_info,
145 pieces_in_sector,
146 replotting,
147 progress_sender,
148 )
149 .await;
150
151 true
152 }
153}
154
155impl<PG, PosTable> CpuPlotter<PG, PosTable>
156where
157 PG: PieceGetter + Clone + Send + Sync + 'static,
158 PosTable: Table,
159{
160 pub fn new(
162 piece_getter: PG,
163 downloading_semaphore: Arc<Semaphore>,
164 plotting_thread_pool_manager: PlottingThreadPoolManager,
165 record_encoding_concurrency: NonZeroUsize,
166 global_mutex: Arc<AsyncMutex<()>>,
167 erasure_coding: ErasureCoding,
168 registry: Option<&mut Registry>,
169 ) -> Self {
170 let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
171
172 let background_tasks = AsyncJoinOnDrop::new(
174 tokio::spawn(async move {
175 let background_tasks = FuturesUnordered::new();
176 let mut background_tasks = pin!(background_tasks);
177 background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
179
180 loop {
181 select! {
182 maybe_background_task = tasks_receiver.next().fuse() => {
183 let Some(background_task) = maybe_background_task else {
184 break;
185 };
186
187 background_tasks.push(background_task);
188 },
189 _ = background_tasks.select_next_some() => {
190 }
192 }
193 }
194 }),
195 true,
196 );
197
198 let abort_early = Arc::new(AtomicBool::new(false));
199 let metrics = registry.map(|registry| {
200 Arc::new(CpuPlotterMetrics::new(
201 registry,
202 type_name::<PosTable>(),
203 plotting_thread_pool_manager.thread_pool_pairs(),
204 ))
205 });
206
207 Self {
208 piece_getter,
209 downloading_semaphore,
210 plotting_thread_pool_manager,
211 record_encoding_concurrency,
212 global_mutex,
213 erasure_coding,
214 handlers: Arc::default(),
215 tasks_sender,
216 _background_tasks: background_tasks,
217 abort_early,
218 metrics,
219 _phantom: PhantomData,
220 }
221 }
222
223 pub fn on_plotting_progress(
225 &self,
226 callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
227 ) -> HandlerId {
228 self.handlers.plotting_progress.add(callback)
229 }
230
231 #[expect(clippy::too_many_arguments)]
232 async fn plot_sector_internal<PS>(
233 &self,
234 start: Instant,
235 downloading_permit: SemaphoreGuardArc,
236 public_key: Ed25519PublicKey,
237 shard_commitments_root: ShardCommitmentHash,
238 sector_index: SectorIndex,
239 farmer_protocol_info: FarmerProtocolInfo,
240 pieces_in_sector: u16,
241 replotting: bool,
242 mut progress_sender: PS,
243 ) where
244 PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
245 PS::Error: Error,
246 {
247 if let Some(metrics) = &self.metrics {
248 metrics.sector_plotting.inc();
249 }
250
251 let progress_updater = ProgressUpdater {
252 public_key,
253 sector_index,
254 handlers: Arc::clone(&self.handlers),
255 metrics: self.metrics.clone(),
256 };
257
258 let plotting_fut = {
259 let piece_getter = self.piece_getter.clone();
260 let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
261 let record_encoding_concurrency = self.record_encoding_concurrency;
262 let global_mutex = Arc::clone(&self.global_mutex);
263 let erasure_coding = self.erasure_coding.clone();
264 let abort_early = Arc::clone(&self.abort_early);
265 let metrics = self.metrics.clone();
266
267 async move {
268 let downloaded_sector = {
270 if !progress_updater
271 .update_progress_and_events(
272 &mut progress_sender,
273 SectorPlottingProgress::Downloading,
274 )
275 .await
276 {
277 return;
278 }
279
280 global_mutex.lock().await;
282
283 let downloading_start = Instant::now();
284 let public_key_hash = &public_key.hash();
285
286 let downloaded_sector_fut = download_sector(DownloadSectorOptions {
287 public_key_hash,
288 shard_commitments_root: &shard_commitments_root,
289 sector_index,
290 piece_getter: &piece_getter,
291 farmer_protocol_info,
292 erasure_coding: &erasure_coding,
293 pieces_in_sector,
294 });
295
296 let downloaded_sector = match downloaded_sector_fut.await {
297 Ok(downloaded_sector) => downloaded_sector,
298 Err(error) => {
299 warn!(%error, "Failed to download sector");
300
301 progress_updater
302 .update_progress_and_events(
303 &mut progress_sender,
304 SectorPlottingProgress::Error {
305 error: format!("Failed to download sector: {error}"),
306 },
307 )
308 .await;
309
310 return;
311 }
312 };
313
314 if !progress_updater
315 .update_progress_and_events(
316 &mut progress_sender,
317 SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
318 )
319 .await
320 {
321 return;
322 }
323
324 downloaded_sector
325 };
326
327 let (sector, plotted_sector) = {
329 let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
330 if let Some(metrics) = &metrics {
331 metrics.plotting_capacity_used.inc();
332 }
333
334 yield_now().await;
336
337 if !progress_updater
338 .update_progress_and_events(
339 &mut progress_sender,
340 SectorPlottingProgress::Encoding,
341 )
342 .await
343 {
344 if let Some(metrics) = &metrics {
345 metrics.plotting_capacity_used.dec();
346 }
347 return;
348 }
349
350 let encoding_start = Instant::now();
351
352 let plotting_result = tokio::task::block_in_place(move || {
353 let thread_pool = if replotting {
354 &thread_pools.replotting
355 } else {
356 &thread_pools.plotting
357 };
358
359 let encoded_sector = thread_pool.install(|| {
360 let generator = PosTable::generator();
362 let generators = (0..record_encoding_concurrency.get())
363 .map(|_| generator.clone())
364 .collect::<Vec<_>>();
365 let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
366 &generators,
367 &erasure_coding,
368 &global_mutex,
369 );
370
371 encode_sector(
372 downloaded_sector,
373 EncodeSectorOptions {
374 sector_index,
375 records_encoder: &mut records_encoder,
376 abort_early: &abort_early,
377 },
378 )
379 })?;
380
381 if abort_early.load(Ordering::Acquire) {
382 return Err(PlottingError::AbortEarly);
383 }
384
385 drop(thread_pools);
386
387 let mut sector = Vec::new();
388
389 write_sector(&encoded_sector, &mut sector)?;
390
391 Ok((sector, encoded_sector.plotted_sector))
392 });
393
394 if let Some(metrics) = &metrics {
395 metrics.plotting_capacity_used.dec();
396 }
397
398 match plotting_result {
399 Ok(plotting_result) => {
400 if !progress_updater
401 .update_progress_and_events(
402 &mut progress_sender,
403 SectorPlottingProgress::Encoded(encoding_start.elapsed()),
404 )
405 .await
406 {
407 return;
408 }
409
410 plotting_result
411 }
412 Err(PlottingError::AbortEarly) => {
413 return;
414 }
415 Err(error) => {
416 progress_updater
417 .update_progress_and_events(
418 &mut progress_sender,
419 SectorPlottingProgress::Error {
420 error: format!("Failed to encode sector: {error}"),
421 },
422 )
423 .await;
424
425 return;
426 }
427 }
428 };
429
430 progress_updater
431 .update_progress_and_events(
432 &mut progress_sender,
433 SectorPlottingProgress::Finished {
434 plotted_sector,
435 time: start.elapsed(),
436 sector: Box::pin({
437 let mut sector = Some(Ok(Bytes::from(sector)));
438
439 stream::poll_fn(move |_cx| {
440 let _downloading_permit = &downloading_permit;
442
443 Poll::Ready(sector.take())
444 })
445 }),
446 },
447 )
448 .await;
449 }
450 };
451
452 let plotting_task =
454 AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
455 if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
456 warn!(%error, "Failed to send plotting task");
457
458 let progress = SectorPlottingProgress::Error {
459 error: format!("Failed to send plotting task: {error}"),
460 };
461
462 self.handlers
463 .plotting_progress
464 .call_simple(&public_key, §or_index, &progress);
465 }
466 }
467}
468
469struct ProgressUpdater {
470 public_key: Ed25519PublicKey,
471 sector_index: SectorIndex,
472 handlers: Arc<Handlers>,
473 metrics: Option<Arc<CpuPlotterMetrics>>,
474}
475
476impl ProgressUpdater {
477 async fn update_progress_and_events<PS>(
479 &self,
480 progress_sender: &mut PS,
481 progress: SectorPlottingProgress,
482 ) -> bool
483 where
484 PS: Sink<SectorPlottingProgress> + Unpin,
485 PS::Error: Error,
486 {
487 if let Some(metrics) = &self.metrics {
488 match &progress {
489 SectorPlottingProgress::Downloading => {
490 metrics.sector_downloading.inc();
491 }
492 SectorPlottingProgress::Downloaded(time) => {
493 metrics.sector_downloading_time.observe(time.as_secs_f64());
494 metrics.sector_downloaded.inc();
495 }
496 SectorPlottingProgress::Encoding => {
497 metrics.sector_encoding.inc();
498 }
499 SectorPlottingProgress::Encoded(time) => {
500 metrics.sector_encoding_time.observe(time.as_secs_f64());
501 metrics.sector_encoded.inc();
502 }
503 SectorPlottingProgress::Finished { time, .. } => {
504 metrics.sector_plotting_time.observe(time.as_secs_f64());
505 metrics.sector_plotted.inc();
506 }
507 SectorPlottingProgress::Error { .. } => {
508 metrics.sector_plotting_error.inc();
509 }
510 }
511 }
512 self.handlers.plotting_progress.call_simple(
513 &self.public_key,
514 &self.sector_index,
515 &progress,
516 );
517
518 if let Err(error) = progress_sender.send(progress).await {
519 warn!(%error, "Failed to send progress update");
520
521 false
522 } else {
523 true
524 }
525 }
526}