Skip to main content

ab_networking/protocols/request_response/handlers/
generic_request_handler.rs

//! Generic request-response handler, typically used with a type implementing [`GenericRequest`]
//! to significantly reduce boilerplate when implementing [`RequestHandler`].
3
4use crate::protocols::request_response::request_response_factory::{
5    IncomingRequest, OutgoingResponse, ProtocolConfig, RequestHandler,
6};
7use async_trait::async_trait;
8use futures::channel::mpsc;
9use futures::prelude::*;
10use libp2p::PeerId;
11use parity_scale_codec::{Decode, Encode};
12use std::fmt;
13use std::pin::Pin;
14use std::sync::Arc;
15use tracing::{debug, trace};
16
/// Buffer size passed to `mpsc::channel` when creating a handler's inbound request queue.
///
/// This is an initial value; it may be tuned based on production feedback.
const REQUESTS_BUFFER_SIZE: usize = 50;
19
20/// Generic request with associated response
21pub trait GenericRequest: Encode + Decode + Send + Sync + 'static {
22    /// Defines request-response protocol name.
23    const PROTOCOL_NAME: &'static str;
24    /// Specifies log-parameters for tracing.
25    const LOG_TARGET: &'static str;
26    /// Response type that corresponds to this request
27    type Response: Encode + Decode + Send + Sync + 'static;
28}
29
30type RequestHandlerFn<Request> = Arc<
31    dyn (Fn(
32            PeerId,
33            Request,
34        )
35            -> Pin<Box<dyn Future<Output = Option<<Request as GenericRequest>::Response>> + Send>>)
36        + Send
37        + Sync
38        + 'static,
39>;
40
41/// Defines generic request-response protocol handler.
42pub struct GenericRequestHandler<Request>
43where
44    Request: GenericRequest,
45{
46    request_receiver: mpsc::Receiver<IncomingRequest>,
47    request_handler: RequestHandlerFn<Request>,
48    protocol_config: ProtocolConfig,
49}
50
51impl<Request> fmt::Debug for GenericRequestHandler<Request>
52where
53    Request: GenericRequest,
54{
55    #[inline]
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("GenericRequestHandler")
58            .field("protocol_name", &Request::PROTOCOL_NAME)
59            .finish_non_exhaustive()
60    }
61}
62
63impl<Request> GenericRequestHandler<Request>
64where
65    Request: GenericRequest,
66{
67    /// Creates new [`GenericRequestHandler`] by given handler.
68    pub fn create<RH, Fut>(request_handler: RH) -> Box<dyn RequestHandler>
69    where
70        RH: (Fn(PeerId, Request) -> Fut) + Send + Sync + 'static,
71        Fut: Future<Output = Option<Request::Response>> + Send + 'static,
72    {
73        let (request_sender, request_receiver) = mpsc::channel(REQUESTS_BUFFER_SIZE);
74
75        let mut protocol_config = ProtocolConfig::new(Request::PROTOCOL_NAME);
76        protocol_config.inbound_queue = Some(request_sender);
77
78        Box::new(Self {
79            request_receiver,
80            request_handler: Arc::new(move |peer_id, request| {
81                Box::pin(request_handler(peer_id, request))
82            }),
83            protocol_config,
84        })
85    }
86
87    /// Invokes external protocol handler.
88    async fn handle_request(
89        &self,
90        peer: PeerId,
91        payload: Vec<u8>,
92    ) -> Result<Vec<u8>, RequestHandlerError> {
93        trace!(%peer, protocol=Request::LOG_TARGET, "Handling request...");
94        let request = Request::decode(&mut payload.as_slice())
95            .map_err(|_error| RequestHandlerError::InvalidRequestFormat)?;
96        let response = (self.request_handler)(peer, request).await;
97
98        Ok(response.ok_or(RequestHandlerError::NoResponse)?.encode())
99    }
100}
101
102#[async_trait]
103impl<Request> RequestHandler for GenericRequestHandler<Request>
104where
105    Request: GenericRequest,
106{
107    /// Run [`RequestHandler`].
108    async fn run(&mut self) {
109        while let Some(request) = self.request_receiver.next().await {
110            let IncomingRequest {
111                peer,
112                payload,
113                pending_response,
114            } = request;
115
116            match self.handle_request(peer, payload).await {
117                Ok(response_data) => {
118                    let response = OutgoingResponse {
119                        result: Ok(response_data),
120                        sent_feedback: None,
121                    };
122
123                    if pending_response.send(response).is_ok() {
124                        trace!(target = Request::LOG_TARGET, %peer, "Handled request");
125                    } else {
126                        debug!(
127                            target = Request::LOG_TARGET,
128                            protocol = Request::PROTOCOL_NAME,
129                            %peer,
130                            "Failed to handle request: {}",
131                            RequestHandlerError::SendResponse
132                        );
133                    }
134                }
135                Err(e) => {
136                    debug!(
137                        target = Request::LOG_TARGET,
138                        protocol = Request::PROTOCOL_NAME,
139                        %e,
140                        "Failed to handle request.",
141                    );
142
143                    let response = OutgoingResponse {
144                        result: Err(()),
145                        sent_feedback: None,
146                    };
147
148                    if pending_response.send(response).is_err() {
149                        debug!(
150                            target = Request::LOG_TARGET,
151                            protocol = Request::PROTOCOL_NAME,
152                            %peer,
153                            "Failed to handle request: {}", RequestHandlerError::SendResponse
154                        );
155                    }
156                }
157            }
158        }
159    }
160
161    fn protocol_config(&self) -> ProtocolConfig {
162        self.protocol_config.clone()
163    }
164
165    fn protocol_name(&self) -> &'static str {
166        Request::PROTOCOL_NAME
167    }
168
169    fn clone_box(&self) -> Box<dyn RequestHandler> {
170        let (request_sender, request_receiver) = mpsc::channel(REQUESTS_BUFFER_SIZE);
171
172        let mut protocol_config = ProtocolConfig::new(Request::PROTOCOL_NAME);
173        protocol_config.inbound_queue = Some(request_sender);
174
175        Box::new(Self {
176            request_receiver,
177            request_handler: Arc::clone(&self.request_handler),
178            protocol_config,
179        })
180    }
181}
182
183#[derive(Debug, thiserror::Error)]
184enum RequestHandlerError {
185    #[error("Failed to send response.")]
186    SendResponse,
187
188    #[error("Incorrect request format.")]
189    InvalidRequestFormat,
190
191    #[error("No response.")]
192    NoResponse,
193}