Skip to main content

ab_farmer/cluster/
nats_client.rs

1//! NATS client
2//!
3//! [`NatsClient`] provided here is a wrapper around [`Client`] that provides convenient methods
4//! using domain-specific traits.
5//!
6//! Before reading code, make sure to familiarize yourself with NATS documentation, especially with
7//! [subjects](https://docs.nats.io/nats-concepts/subjects) and
8//! [Core NATS](https://docs.nats.io/nats-concepts/core-nats) features.
9//!
10//! Abstractions provided here cover a few use cases:
11//! * request/response (for example piece request)
12//! * request/stream of responses (for example a stream of plotted sectors of the farmer)
13//! * notifications (typically targeting a particular instance of an app) and corresponding
14//!   subscriptions (for example solution notification)
15//! * broadcasts and corresponding subscriptions (for example slot info broadcast)
16
17use crate::utils::AsyncJoinOnDrop;
18use anyhow::anyhow;
19use async_nats::{
20    Client, ConnectOptions, HeaderMap, HeaderValue, Message, PublishError, RequestError,
21    RequestErrorKind, Subject, SubscribeError, Subscriber, ToServerAddrs,
22};
23use backon::{BackoffBuilder, ExponentialBuilder};
24use futures::channel::mpsc;
25use futures::stream::FuturesUnordered;
26use futures::{FutureExt, Stream, StreamExt, select};
27use parity_scale_codec::{Decode, Encode};
28use std::any::type_name;
29use std::collections::VecDeque;
30use std::fmt;
31use std::future::Future;
32use std::marker::PhantomData;
33use std::ops::Deref;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::task::{Context, Poll};
37use std::time::Duration;
38use thiserror::Error;
39use tracing::{Instrument, debug, error, trace, warn};
40use ulid::Ulid;
41
/// Minimum `max_payload` the NATS server has to allow; [`NatsClient::from_client`] rejects
/// servers configured with a smaller limit
const EXPECTED_MESSAGE_SIZE: usize = 2 * 1024 * 1024;
/// Timeout related to stream response acknowledgements.
///
/// NOTE(review): not used in the visible part of this file; presumably the time a stream
/// responder waits for an acknowledgement — confirm.
const ACKNOWLEDGEMENT_TIMEOUT: Duration = Duration::from_secs(60);
/// Requests have to time out eventually, but a generous timeout lets spikes in load be absorbed
/// gracefully instead of failing requests early
const REQUEST_TIMEOUT: Duration = Duration::from_secs(5 * 60);
47
48/// Generic request with associated response.
49///
50/// Used for cases where request/response pattern is needed and response contains a single small
51/// message. For large messages or multiple messages chunking with [`GenericStreamRequest`] can be
52/// used instead.
53pub trait GenericRequest: Encode + Decode + fmt::Debug + Send + Sync + 'static {
54    /// Request subject with optional `*` in place of application instance to receive the request
55    const SUBJECT: &'static str;
56    /// Response type that corresponds to this request
57    type Response: Encode + Decode + fmt::Debug + Send + Sync + 'static;
58}
59
60/// Generic stream request where response is streamed using
61/// [`NatsClient::stream_request_responder`].
62///
63/// Used for cases where a large payload that doesn't fit into NATS message needs to be sent or
64/// there is a very large number of messages to send. For simple request/response patten
65/// [`GenericRequest`] can be used instead.
66pub trait GenericStreamRequest: Encode + Decode + fmt::Debug + Send + Sync + 'static {
67    /// Request subject with optional `*` in place of application instance to receive the request
68    const SUBJECT: &'static str;
69    /// Response type that corresponds to this stream request.
70    ///
71    /// These responses are send as a stream of messages, each message must fit into NATS message,
72    /// [`NatsClient::approximate_max_message_size()`] can be used to estimate appropriate message
73    /// size in case chunking is needed.
74    type Response: Encode + Decode + fmt::Debug + Send + Sync + 'static;
75}
76
77/// Messages sent in response to [`GenericStreamRequest`].
78///
79/// Empty list of responses means the end of the stream.
80#[derive(Debug, Encode, Decode)]
81enum GenericStreamResponses<Response> {
82    /// Some responses, but the stream didn't end yet
83    Continue {
84        /// Monotonically increasing index of responses in a stream
85        index: u32,
86        /// Individual responses
87        responses: VecDeque<Response>,
88        /// Subject where to send acknowledgement of received stream response indices, which acts
89        /// as a backpressure mechanism
90        ack_subject: String,
91    },
92    /// Remaining responses and this is the end of the stream.
93    Last {
94        /// Monotonically increasing index of responses in a stream
95        index: u32,
96        /// Individual responses
97        responses: VecDeque<Response>,
98    },
99}
100
101impl<Response> From<GenericStreamResponses<Response>> for VecDeque<Response> {
102    #[inline]
103    fn from(value: GenericStreamResponses<Response>) -> Self {
104        match value {
105            GenericStreamResponses::Continue { responses, .. } => responses,
106            GenericStreamResponses::Last { responses, .. } => responses,
107        }
108    }
109}
110
111impl<Response> GenericStreamResponses<Response> {
112    fn next(&mut self) -> Option<Response> {
113        match self {
114            GenericStreamResponses::Continue { responses, .. } => responses.pop_front(),
115            GenericStreamResponses::Last { responses, .. } => responses.pop_front(),
116        }
117    }
118
119    fn index(&self) -> u32 {
120        match self {
121            GenericStreamResponses::Continue { index, .. } => *index,
122            GenericStreamResponses::Last { index, .. } => *index,
123        }
124    }
125
126    fn ack_subject(&self) -> Option<&str> {
127        if let GenericStreamResponses::Continue { ack_subject, .. } = self {
128            Some(ack_subject)
129        } else {
130            None
131        }
132    }
133
134    fn is_last(&self) -> bool {
135        matches!(self, Self::Last { .. })
136    }
137}
138
139/// Stream request error
140#[derive(Debug, Error)]
141pub enum StreamRequestError {
142    /// Subscribe error
143    #[error("Subscribe error: {0}")]
144    Subscribe(#[from] SubscribeError),
145    /// Publish error
146    #[error("Publish error: {0}")]
147    Publish(#[from] PublishError),
148}
149
150/// Wrapper around subscription that transforms stream of wrapped response messages into a normal
151/// `Response` stream.
152#[derive(Debug)]
153#[pin_project::pin_project]
154pub struct StreamResponseSubscriber<Response> {
155    #[pin]
156    subscriber: Subscriber,
157    response_subject: String,
158    buffered_responses: Option<GenericStreamResponses<Response>>,
159    next_index: u32,
160    acknowledgement_sender: mpsc::UnboundedSender<(String, u32)>,
161    _background_task: AsyncJoinOnDrop<()>,
162    _phantom: PhantomData<Response>,
163}
164
165impl<Response> Stream for StreamResponseSubscriber<Response>
166where
167    Response: Decode,
168{
169    type Item = Response;
170
171    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172        if let Some(buffered_responses) = self.buffered_responses.as_mut() {
173            if let Some(response) = buffered_responses.next() {
174                return Poll::Ready(Some(response));
175            } else if buffered_responses.is_last() {
176                return Poll::Ready(None);
177            }
178
179            self.buffered_responses.take();
180            self.next_index += 1;
181        }
182
183        let mut projected = self.project();
184        match projected.subscriber.poll_next_unpin(cx) {
185            Poll::Ready(Some(message)) => {
186                match GenericStreamResponses::<Response>::decode(&mut message.payload.as_ref()) {
187                    Ok(mut responses) => {
188                        if responses.index() != *projected.next_index {
189                            warn!(
190                                actual_index = %responses.index(),
191                                expected_index = %*projected.next_index,
192                                message_type = %type_name::<Response>(),
193                                response_subject = %projected.response_subject,
194                                "Received unexpected response stream index, aborting stream"
195                            );
196
197                            return Poll::Ready(None);
198                        }
199
200                        if let Some(ack_subject) = responses.ack_subject() {
201                            let index = responses.index();
202                            let ack_subject = ack_subject.to_string();
203
204                            if let Err(error) = projected
205                                .acknowledgement_sender
206                                .unbounded_send((ack_subject.clone(), index))
207                            {
208                                warn!(
209                                    %error,
210                                    %index,
211                                    message_type = %type_name::<Response>(),
212                                    response_subject = %projected.response_subject,
213                                    %ack_subject,
214                                    "Failed to send acknowledgement for stream response"
215                                );
216                            }
217                        }
218
219                        if let Some(response) = responses.next() {
220                            *projected.buffered_responses = Some(responses);
221                            Poll::Ready(Some(response))
222                        } else {
223                            Poll::Ready(None)
224                        }
225                    }
226                    Err(error) => {
227                        warn!(
228                            %error,
229                            response_type = %type_name::<Response>(),
230                            response_subject = %projected.response_subject,
231                            message = %hex::encode(message.payload),
232                            "Failed to decode stream response"
233                        );
234
235                        Poll::Ready(None)
236                    }
237                }
238            }
239            Poll::Ready(None) => Poll::Ready(None),
240            Poll::Pending => Poll::Pending,
241        }
242    }
243}
244
245impl<Response> StreamResponseSubscriber<Response> {
246    fn new(subscriber: Subscriber, response_subject: String, nats_client: NatsClient) -> Self {
247        let (acknowledgement_sender, mut acknowledgement_receiver) =
248            mpsc::unbounded::<(String, u32)>();
249
250        let ack_publisher_fut = {
251            let response_subject = response_subject.clone();
252
253            async move {
254                while let Some((subject, index)) = acknowledgement_receiver.next().await {
255                    trace!(
256                        %subject,
257                        %index,
258                        %response_subject,
259                        %index,
260                        "Sending stream response acknowledgement"
261                    );
262                    if let Err(error) = nats_client
263                        .publish(subject.clone(), index.to_le_bytes().to_vec().into())
264                        .await
265                    {
266                        warn!(
267                            %error,
268                            %subject,
269                            %index,
270                            %response_subject,
271                            %index,
272                            "Failed to send stream response acknowledgement"
273                        );
274                        return;
275                    }
276                }
277            }
278        };
279        let background_task =
280            AsyncJoinOnDrop::new(tokio::spawn(ack_publisher_fut.in_current_span()), true);
281
282        Self {
283            response_subject,
284            subscriber,
285            buffered_responses: None,
286            next_index: 0,
287            acknowledgement_sender,
288            _background_task: background_task,
289            _phantom: PhantomData,
290        }
291    }
292}
293
294/// Generic one-off notification
295pub trait GenericNotification: Encode + Decode + fmt::Debug + Send + Sync + 'static {
296    /// Notification subject with optional `*` in place of application instance receiving the
297    /// request
298    const SUBJECT: &'static str;
299}
300
301/// Generic broadcast message.
302///
303/// Broadcast messages are sent by an instance to (potentially) an instance-specific subject that
304/// any other app can subscribe to. The same broadcast message can also originate from multiple
305/// places and be de-duplicated using [`Self::deterministic_message_id`].
306pub trait GenericBroadcast: Encode + Decode + fmt::Debug + Send + Sync + 'static {
307    /// Broadcast subject with optional `*` in place of application instance sending broadcast
308    const SUBJECT: &'static str;
309
310    /// Deterministic message ID that is used for de-duplicating messages broadcast by different
311    /// instances
312    fn deterministic_message_id(&self) -> Option<HeaderValue> {
313        None
314    }
315}
316
317/// Subscriber wrapper that decodes messages automatically and skips messages that can't be decoded
318#[derive(Debug)]
319#[pin_project::pin_project]
320pub struct SubscriberWrapper<Message> {
321    #[pin]
322    subscriber: Subscriber,
323    _phantom: PhantomData<Message>,
324}
325
326impl<Message> Stream for SubscriberWrapper<Message>
327where
328    Message: Decode,
329{
330    type Item = Message;
331
332    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
333        match self.project().subscriber.poll_next_unpin(cx) {
334            Poll::Ready(Some(message)) => match Message::decode(&mut message.payload.as_ref()) {
335                Ok(message) => Poll::Ready(Some(message)),
336                Err(error) => {
337                    warn!(
338                        %error,
339                        message_type = %type_name::<Message>(),
340                        message = %hex::encode(message.payload),
341                        "Failed to decode stream message"
342                    );
343
344                    Poll::Pending
345                }
346            },
347            Poll::Ready(None) => Poll::Ready(None),
348            Poll::Pending => Poll::Pending,
349        }
350    }
351}
352
353#[derive(Debug)]
354struct Inner {
355    client: Client,
356    request_retry_backoff_policy: ExponentialBuilder,
357    approximate_max_message_size: usize,
358    max_message_size: usize,
359}
360
361/// NATS client wrapper that can be used to interact with other cluster-specific clients
362#[derive(Debug, Clone)]
363pub struct NatsClient {
364    inner: Arc<Inner>,
365}
366
367impl Deref for NatsClient {
368    type Target = Client;
369
370    #[inline]
371    fn deref(&self) -> &Self::Target {
372        &self.inner.client
373    }
374}
375
376impl NatsClient {
377    /// Create a new instance by connecting to specified addresses
378    pub async fn new<A: ToServerAddrs>(
379        addrs: A,
380        request_retry_backoff_policy: ExponentialBuilder,
381    ) -> Result<Self, async_nats::Error> {
382        let servers = addrs.to_server_addrs()?.collect::<Vec<_>>();
383        Self::from_client(
384            async_nats::connect_with_options(
385                &servers,
386                ConnectOptions::default().request_timeout(Some(REQUEST_TIMEOUT)),
387            )
388            .await?,
389            request_retry_backoff_policy,
390        )
391    }
392
393    /// Create new client from existing NATS instance
394    pub fn from_client(
395        client: Client,
396        request_retry_backoff_policy: ExponentialBuilder,
397    ) -> Result<Self, async_nats::Error> {
398        let max_payload = client.server_info().max_payload;
399        if max_payload < EXPECTED_MESSAGE_SIZE {
400            return Err(format!(
401                "Max payload {max_payload} is smaller than expected {EXPECTED_MESSAGE_SIZE}, \
402                increase it by specifying max_payload = 2MB or higher number in NATS configuration"
403            )
404            .into());
405        }
406
407        let inner = Inner {
408            client,
409            request_retry_backoff_policy,
410            // Allow up to 90%, the rest will be wrapper data structures, etc.
411            approximate_max_message_size: max_payload * 9 / 10,
412            // Allow up to 90%, the rest will be wrapper data structures, etc.
413            max_message_size: max_payload,
414        };
415
416        Ok(Self {
417            inner: Arc::new(inner),
418        })
419    }
420
421    /// Approximate max message size (a few more bytes will not hurt), the actual limit is expected
422    /// to be a bit higher
423    pub fn approximate_max_message_size(&self) -> usize {
424        self.inner.approximate_max_message_size
425    }
426
427    /// Make request and wait for response
428    pub async fn request<Request>(
429        &self,
430        request: &Request,
431        instance: Option<&str>,
432    ) -> Result<Request::Response, RequestError>
433    where
434        Request: GenericRequest,
435    {
436        let subject = subject_with_instance(Request::SUBJECT, instance);
437        let mut maybe_retry_backoff = None;
438        let message = loop {
439            match self
440                .inner
441                .client
442                .request(subject.clone(), request.encode().into())
443                .await
444            {
445                Ok(message) => {
446                    break message;
447                }
448                Err(error) => {
449                    match error.kind() {
450                        RequestErrorKind::TimedOut | RequestErrorKind::NoResponders => {
451                            // Continue with retries
452                        }
453                        RequestErrorKind::InvalidSubject | RequestErrorKind::Other => {
454                            return Err(error);
455                        }
456                    }
457
458                    let retry_backoff = maybe_retry_backoff
459                        .get_or_insert_with(|| self.inner.request_retry_backoff_policy.build());
460
461                    if let Some(delay) = retry_backoff.next() {
462                        debug!(
463                            %subject,
464                            %error,
465                            request_type = %type_name::<Request>(),
466                            ?delay,
467                            "Failed to make request, retrying after some delay"
468                        );
469
470                        tokio::time::sleep(delay).await;
471                        continue;
472                    } else {
473                        return Err(error);
474                    }
475                }
476            }
477        };
478
479        let response =
480            Request::Response::decode(&mut message.payload.as_ref()).map_err(|error| {
481                warn!(
482                    %subject,
483                    %error,
484                    response_type = %type_name::<Request::Response>(),
485                    response = %hex::encode(message.payload),
486                    "Response decoding failed"
487                );
488
489                RequestErrorKind::Other
490            })?;
491
492        Ok(response)
493    }
494
495    /// Responds to requests from the given subject using the provided processing function.
496    ///
497    /// This will create a subscription on the subject for the given instance (if provided) and
498    /// queue group. Incoming messages will be deserialized as the request type `Request` and passed
499    /// to the `process` function to produce a response of type `Request::Response`. The response
500    /// will then be sent back on the reply subject from the original request.
501    ///
502    /// Each request is processed in a newly created async tokio task.
503    ///
504    /// # Arguments
505    ///
506    /// * `instance` - Optional instance name to use in place of the `*` in the subject
507    /// * `group` - The queue group name for the subscription
508    /// * `process` - The function to call with the decoded request to produce a response
509    pub async fn request_responder<Request, F, OP>(
510        &self,
511        instance: Option<&str>,
512        queue_group: Option<String>,
513        process: OP,
514    ) -> anyhow::Result<()>
515    where
516        Request: GenericRequest,
517        F: Future<Output = Option<Request::Response>> + Send,
518        OP: Fn(Request) -> F + Send + Sync,
519    {
520        // Initialize with pending future so it never ends
521        let mut processing = FuturesUnordered::new();
522
523        let subscription = self
524            .common_subscribe(Request::SUBJECT, instance, queue_group)
525            .await
526            .map_err(|error| {
527                anyhow!(
528                    "Failed to subscribe to {} requests for {instance:?}: {error}",
529                    type_name::<Request>(),
530                )
531            })?;
532
533        debug!(
534            request_type = %type_name::<Request>(),
535            ?subscription,
536            "Requests subscription"
537        );
538        let mut subscription = subscription.fuse();
539
540        loop {
541            select! {
542                message = subscription.select_next_some() => {
543                    // Create background task for concurrent processing
544                    processing.push(
545                        self
546                            .process_request(
547                                message,
548                                &process,
549                            )
550                            .in_current_span(),
551                    );
552                },
553                _ = processing.next() => {
554                    // Nothing to do here
555                },
556                complete => {
557                    break;
558                }
559            }
560        }
561
562        Ok(())
563    }
564
565    async fn process_request<Request, F, OP>(&self, message: Message, process: OP)
566    where
567        Request: GenericRequest,
568        F: Future<Output = Option<Request::Response>> + Send,
569        OP: Fn(Request) -> F + Send + Sync,
570    {
571        let Some(reply_subject) = message.reply else {
572            return;
573        };
574
575        let message_payload_size = message.payload.len();
576        let request = match Request::decode(&mut message.payload.as_ref()) {
577            Ok(request) => {
578                // Free allocation early
579                drop(message.payload);
580                request
581            }
582            Err(error) => {
583                warn!(
584                    request_type = %type_name::<Request>(),
585                    %error,
586                    message = %hex::encode(message.payload),
587                    "Failed to decode request"
588                );
589                return;
590            }
591        };
592
593        // Avoid printing large messages in logs
594        if message_payload_size > 1024 {
595            trace!(
596                request_type = %type_name::<Request>(),
597                %reply_subject,
598                "Processing request"
599            );
600        } else {
601            trace!(
602                request_type = %type_name::<Request>(),
603                ?request,
604                %reply_subject,
605                "Processing request"
606            );
607        }
608
609        if let Some(response) = process(request).await
610            && let Err(error) = self.publish(reply_subject, response.encode().into()).await
611        {
612            warn!(
613                request_type = %type_name::<Request>(),
614                %error,
615                "Failed to send response"
616            );
617        }
618    }
619
620    /// Make request that expects stream response
621    pub async fn stream_request<Request>(
622        &self,
623        request: &Request,
624        instance: Option<&str>,
625    ) -> Result<StreamResponseSubscriber<Request::Response>, StreamRequestError>
626    where
627        Request: GenericStreamRequest,
628    {
629        let stream_request_subject = subject_with_instance(Request::SUBJECT, instance);
630        let stream_response_subject = format!("stream-response.{}", Ulid::new());
631
632        let subscriber = self
633            .inner
634            .client
635            .subscribe(stream_response_subject.clone())
636            .await?;
637
638        debug!(
639            request_type = %type_name::<Request>(),
640            %stream_request_subject,
641            %stream_response_subject,
642            ?subscriber,
643            "Stream request subscription"
644        );
645
646        self.inner
647            .client
648            .publish_with_reply(
649                stream_request_subject,
650                stream_response_subject.clone(),
651                request.encode().into(),
652            )
653            .await?;
654
655        Ok(StreamResponseSubscriber::new(
656            subscriber,
657            stream_response_subject,
658            self.clone(),
659        ))
660    }
661
662    /// Responds to stream requests from the given subject using the provided processing function.
663    ///
664    /// This will create a subscription on the subject for the given instance (if provided) and
665    /// queue group. Incoming messages will be deserialized as the request type `Request` and passed
666    /// to the `process` function to produce a stream response of type `Request::Response`. The
667    /// stream response will then be sent back on the reply subject from the original request.
668    ///
669    /// Each request is processed in a newly created async tokio task.
670    ///
671    /// # Arguments
672    ///
673    /// * `instance` - Optional instance name to use in place of the `*` in the subject
674    /// * `group` - The queue group name for the subscription
675    /// * `process` - The function to call with the decoded request to produce a response
676    pub async fn stream_request_responder<Request, F, S, OP>(
677        &self,
678        instance: Option<&str>,
679        queue_group: Option<String>,
680        process: OP,
681    ) -> anyhow::Result<()>
682    where
683        Request: GenericStreamRequest,
684        F: Future<Output = Option<S>> + Send,
685        S: Stream<Item = Request::Response> + Unpin,
686        OP: Fn(Request) -> F + Send + Sync,
687    {
688        // Initialize with pending future so it never ends
689        let mut processing = FuturesUnordered::new();
690
691        let subscription = self
692            .common_subscribe(Request::SUBJECT, instance, queue_group)
693            .await
694            .map_err(|error| {
695                anyhow!(
696                    "Failed to subscribe to {} stream requests for {instance:?}: {error}",
697                    type_name::<Request>(),
698                )
699            })?;
700
701        debug!(
702            request_type = %type_name::<Request>(),
703            ?subscription,
704            "Stream requests subscription"
705        );
706        let mut subscription = subscription.fuse();
707
708        loop {
709            select! {
710                message = subscription.select_next_some() => {
711                    // Create background task for concurrent processing
712                    processing.push(
713                        self
714                        .process_stream_request(
715                            message,
716                            &process,
717                        )
718                        .in_current_span(),
719                    );
720                },
721                _ = processing.next() => {
722                    // Nothing to do here
723                },
724                complete => {
725                    break;
726                }
727            }
728        }
729
730        Ok(())
731    }
732
733    async fn process_stream_request<Request, F, S, OP>(&self, message: Message, process: OP)
734    where
735        Request: GenericStreamRequest,
736        F: Future<Output = Option<S>> + Send,
737        S: Stream<Item = Request::Response> + Unpin,
738        OP: Fn(Request) -> F + Send + Sync,
739    {
740        let Some(reply_subject) = message.reply else {
741            return;
742        };
743
744        let message_payload_size = message.payload.len();
745        let request = match Request::decode(&mut message.payload.as_ref()) {
746            Ok(request) => {
747                // Free allocation early
748                drop(message.payload);
749                request
750            }
751            Err(error) => {
752                warn!(
753                    request_type = %type_name::<Request>(),
754                    %error,
755                    message = %hex::encode(message.payload),
756                    "Failed to decode request"
757                );
758                return;
759            }
760        };
761
762        // Avoid printing large messages in logs
763        if message_payload_size > 1024 {
764            trace!(
765                request_type = %type_name::<Request>(),
766                %reply_subject,
767                "Processing request"
768            );
769        } else {
770            trace!(
771                request_type = %type_name::<Request>(),
772                ?request,
773                %reply_subject,
774                "Processing request"
775            );
776        }
777
778        if let Some(stream) = process(request).await {
779            self.stream_response::<Request, _>(reply_subject, stream)
780                .await;
781        }
782    }
783
784    /// Helper method to send responses to requests initiated with [`Self::stream_request`]
785    async fn stream_response<Request, S>(&self, response_subject: Subject, response_stream: S)
786    where
787        Request: GenericStreamRequest,
788        S: Stream<Item = Request::Response> + Unpin,
789    {
790        type Response<Request> =
791            GenericStreamResponses<<Request as GenericStreamRequest>::Response>;
792
793        let mut response_stream = response_stream.fuse();
794
795        // Pull the first element to measure response size
796        let first_element = match response_stream.next().await {
797            Some(first_element) => first_element,
798            None => {
799                if let Err(error) = self
800                    .publish(
801                        response_subject.clone(),
802                        Response::<Request>::Last {
803                            index: 0,
804                            responses: VecDeque::new(),
805                        }
806                        .encode()
807                        .into(),
808                    )
809                    .await
810                {
811                    warn!(
812                        %response_subject,
813                        %error,
814                        request_type = %type_name::<Request>(),
815                        response_type = %type_name::<Request::Response>(),
816                        "Failed to send stream response"
817                    );
818                }
819
820                return;
821            }
822        };
823        let max_message_size = self.inner.max_message_size;
824        let approximate_max_message_size = self.approximate_max_message_size();
825        let max_responses_per_message = approximate_max_message_size / first_element.encoded_size();
826
827        let ack_subject = format!("stream-response-ack.{}", Ulid::new());
828        let mut ack_subscription = match self.subscribe(ack_subject.clone()).await {
829            Ok(ack_subscription) => ack_subscription,
830            Err(error) => {
831                warn!(
832                    %response_subject,
833                    %error,
834                    request_type = %type_name::<Request>(),
835                    response_type = %type_name::<Request::Response>(),
836                    "Failed to subscribe to ack subject"
837                );
838                return;
839            }
840        };
841        debug!(
842            %response_subject,
843            request_type = %type_name::<Request>(),
844            response_type = %type_name::<Request::Response>(),
845            ?ack_subscription,
846            "Ack subscription subscription"
847        );
848        let mut index = 0;
849        // Initialize buffer that will be reused for responses
850        let mut buffer = VecDeque::with_capacity(max_responses_per_message);
851        buffer.push_back(first_element);
852        let mut overflow_buffer = VecDeque::new();
853
854        loop {
855            // Try to fill the buffer
856            if buffer.is_empty()
857                && let Some(element) = response_stream.next().await
858            {
859                buffer.push_back(element);
860            }
861            while buffer.encoded_size() < approximate_max_message_size
862                && let Some(element) = response_stream.next().now_or_never().flatten()
863            {
864                buffer.push_back(element);
865            }
866
867            loop {
868                let is_done = response_stream.is_done() && overflow_buffer.is_empty();
869                let num_messages = buffer.len();
870                let response = if is_done {
871                    Response::<Request>::Last {
872                        index,
873                        responses: buffer,
874                    }
875                } else {
876                    Response::<Request>::Continue {
877                        index,
878                        responses: buffer,
879                        ack_subject: ack_subject.clone(),
880                    }
881                };
882                let encoded_response = response.encode();
883                let encoded_response_len = encoded_response.len();
884                // When encoded response is too large, remove one of the responses from it and try
885                // again
886                if encoded_response_len > max_message_size {
887                    buffer = response.into();
888                    if let Some(element) = buffer.pop_back() {
889                        if buffer.is_empty() {
890                            error!(
891                                ?element,
892                                encoded_response_len,
893                                max_message_size,
894                                "Element was too large to fit into NATS message, this is an \
895                                implementation bug"
896                            );
897                        }
898                        overflow_buffer.push_front(element);
899                        continue;
900                    } else {
901                        error!(
902                            %response_subject,
903                            request_type = %type_name::<Request>(),
904                            response_type = %type_name::<Request::Response>(),
905                            "Empty response overflown message size, this should never happen"
906                        );
907                        return;
908                    }
909                }
910
911                debug!(
912                    %response_subject,
913                    num_messages,
914                    %index,
915                    %is_done,
916                    "Publishing stream response messages",
917                );
918
919                if let Err(error) = self
920                    .publish(response_subject.clone(), encoded_response.into())
921                    .await
922                {
923                    warn!(
924                        %response_subject,
925                        %error,
926                        request_type = %type_name::<Request>(),
927                        response_type = %type_name::<Request::Response>(),
928                        "Failed to send stream response"
929                    );
930                    return;
931                }
932
933                if is_done {
934                    return;
935                } else {
936                    buffer = response.into();
937                    buffer.clear();
938                    // Fill buffer with any overflown responses that may have been stored
939                    buffer.extend(overflow_buffer.drain(..));
940                }
941
942                if index >= 1 {
943                    // Acknowledgements are received with delay
944                    let expected_index = index - 1;
945
946                    trace!(
947                        %response_subject,
948                        %expected_index,
949                        "Waiting for acknowledgement"
950                    );
951                    match tokio::time::timeout(ACKNOWLEDGEMENT_TIMEOUT, ack_subscription.next())
952                        .await
953                    {
954                        Ok(Some(message)) => {
955                            if let Some(received_index) =
956                                message.payload.get(..size_of::<u32>()).map(|bytes| {
957                                    u32::from_le_bytes(
958                                        bytes.try_into().expect("Correctly chunked slice; qed"),
959                                    )
960                                })
961                            {
962                                debug!(
963                                    %response_subject,
964                                    %received_index,
965                                    "Received acknowledgement"
966                                );
967                                if received_index != expected_index {
968                                    warn!(
969                                        %response_subject,
970                                        %received_index,
971                                        %expected_index,
972                                        request_type = %type_name::<Request>(),
973                                        response_type = %type_name::<Request::Response>(),
974                                        message = %hex::encode(message.payload),
975                                        "Unexpected acknowledgement index"
976                                    );
977                                    return;
978                                }
979                            } else {
980                                warn!(
981                                    %response_subject,
982                                    request_type = %type_name::<Request>(),
983                                    response_type = %type_name::<Request::Response>(),
984                                    message = %hex::encode(message.payload),
985                                    "Unexpected acknowledgement message"
986                                );
987                                return;
988                            }
989                        }
990                        Ok(None) => {
991                            warn!(
992                                %response_subject,
993                                request_type = %type_name::<Request>(),
994                                response_type = %type_name::<Request::Response>(),
995                                "Acknowledgement stream ended unexpectedly"
996                            );
997                            return;
998                        }
999                        Err(_error) => {
1000                            warn!(
1001                                %response_subject,
1002                                %expected_index,
1003                                request_type = %type_name::<Request>(),
1004                                response_type = %type_name::<Request::Response>(),
1005                                "Acknowledgement wait timed out"
1006                            );
1007                            return;
1008                        }
1009                    }
1010                }
1011
1012                index += 1;
1013
1014                // Unless `overflow_buffer` wasn't empty abort inner loop
1015                if buffer.is_empty() {
1016                    break;
1017                }
1018            }
1019        }
1020    }
1021
1022    /// Make notification without waiting for response
1023    pub async fn notification<Notification>(
1024        &self,
1025        notification: &Notification,
1026        instance: Option<&str>,
1027    ) -> Result<(), PublishError>
1028    where
1029        Notification: GenericNotification,
1030    {
1031        self.inner
1032            .client
1033            .publish(
1034                subject_with_instance(Notification::SUBJECT, instance),
1035                notification.encode().into(),
1036            )
1037            .await
1038    }
1039
1040    /// Send a broadcast message
1041    pub async fn broadcast<Broadcast>(
1042        &self,
1043        message: &Broadcast,
1044        instance: &str,
1045    ) -> Result<(), PublishError>
1046    where
1047        Broadcast: GenericBroadcast,
1048    {
1049        self.inner
1050            .client
1051            .publish_with_headers(
1052                Broadcast::SUBJECT.replace('*', instance),
1053                {
1054                    let mut headers = HeaderMap::new();
1055                    if let Some(message_id) = message.deterministic_message_id() {
1056                        headers.insert("Nats-Msg-Id", message_id);
1057                    }
1058                    headers
1059                },
1060                message.encode().into(),
1061            )
1062            .await
1063    }
1064
1065    /// Simple subscription that will produce decoded notifications, while skipping messages that
1066    /// fail to decode
1067    pub async fn subscribe_to_notifications<Notification>(
1068        &self,
1069        instance: Option<&str>,
1070        queue_group: Option<String>,
1071    ) -> Result<SubscriberWrapper<Notification>, SubscribeError>
1072    where
1073        Notification: GenericNotification,
1074    {
1075        self.simple_subscribe(Notification::SUBJECT, instance, queue_group)
1076            .await
1077    }
1078
1079    /// Simple subscription that will produce decoded broadcasts, while skipping messages that
1080    /// fail to decode
1081    pub async fn subscribe_to_broadcasts<Broadcast>(
1082        &self,
1083        instance: Option<&str>,
1084        queue_group: Option<String>,
1085    ) -> Result<SubscriberWrapper<Broadcast>, SubscribeError>
1086    where
1087        Broadcast: GenericBroadcast,
1088    {
1089        self.simple_subscribe(Broadcast::SUBJECT, instance, queue_group)
1090            .await
1091    }
1092
1093    /// Simple subscription that will produce decoded messages, while skipping messages that fail to
1094    /// decode
1095    async fn simple_subscribe<Message>(
1096        &self,
1097        subject: &'static str,
1098        instance: Option<&str>,
1099        queue_group: Option<String>,
1100    ) -> Result<SubscriberWrapper<Message>, SubscribeError>
1101    where
1102        Message: Decode,
1103    {
1104        let subscriber = self
1105            .common_subscribe(subject, instance, queue_group)
1106            .await?;
1107        debug!(
1108            %subject,
1109            message_type = %type_name::<Message>(),
1110            ?subscriber,
1111            "Simple subscription"
1112        );
1113
1114        Ok(SubscriberWrapper {
1115            subscriber,
1116            _phantom: PhantomData,
1117        })
1118    }
1119
1120    /// Simple subscription that will produce decoded messages, while skipping messages that fail to
1121    /// decode
1122    async fn common_subscribe(
1123        &self,
1124        subject: &'static str,
1125        instance: Option<&str>,
1126        queue_group: Option<String>,
1127    ) -> Result<Subscriber, SubscribeError> {
1128        let subscriber = if let Some(queue_group) = queue_group {
1129            self.inner
1130                .client
1131                .queue_subscribe(subject_with_instance(subject, instance), queue_group)
1132                .await?
1133        } else {
1134            self.inner
1135                .client
1136                .subscribe(subject_with_instance(subject, instance))
1137                .await?
1138        };
1139
1140        Ok(subscriber)
1141    }
1142}
1143
1144fn subject_with_instance(subject: &'static str, instance: Option<&str>) -> Subject {
1145    if let Some(instance) = instance {
1146        Subject::from(subject.replace('*', instance))
1147    } else {
1148        Subject::from_static(subject)
1149    }
1150}