1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use core::fmt::{Debug, Formatter, Result as FmtResult};
use core::pin::Pin;
use futures_core::task::{Context, Poll};
use futures_sink::Sink;
use pin_utils::unsafe_pinned;

/// A sink that forwards a clone of every incoming item to two downstream sinks.
///
/// Each item is cloned once: the clone goes to the first sink and the original
/// to the second. Backpressure from either downstream sink propagates up, so the
/// combined sink can only accept items as fast as its _slowest_ downstream sink.
pub struct Fanout<Si1, Si2> {
    sink1: Si1,
    sink2: Si2,
}

impl<Si1, Si2> Fanout<Si1, Si2> {
    // Structural pin projections for both fields; every pinned access to the
    // inner sinks (including `get_pin_mut` below) relies on this pinning contract.
    unsafe_pinned!(sink1: Si1);
    unsafe_pinned!(sink2: Si2);

    /// Creates a `Fanout` that forwards every item to both `sink1` and `sink2`.
    pub(super) fn new(sink1: Si1, sink2: Si2) -> Fanout<Si1, Si2> {
        Fanout { sink1, sink2 }
    }

    /// Get a shared reference to the inner sinks.
    pub fn get_ref(&self) -> (&Si1, &Si2) {
        (&self.sink1, &self.sink2)
    }

    /// Get a mutable reference to the inner sinks.
    ///
    /// Once the combinator has been pinned, use [`Fanout::get_pin_mut`] instead.
    pub fn get_mut(&mut self) -> (&mut Si1, &mut Si2) {
        (&mut self.sink1, &mut self.sink2)
    }

    /// Get a pinned mutable reference to the inner sinks.
    ///
    /// Both fields are structurally pinned, so this works for `!Unpin` sinks
    /// as well as `Unpin` ones.
    pub fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> (Pin<&'a mut Si1>, Pin<&'a mut Si2>) {
        // SAFETY: `sink1` and `sink2` are structurally pinned (the same invariant
        // the `unsafe_pinned!` projections above depend on). We only split the
        // pinned borrow into two disjoint field borrows and re-pin each one; no
        // field is ever moved out of the pinned `Fanout`.
        unsafe {
            let Self { sink1, sink2 } = self.get_unchecked_mut();
            (Pin::new_unchecked(sink1), Pin::new_unchecked(sink2))
        }
    }

    /// Consumes this combinator, returning the underlying sinks.
    ///
    /// Note that this may discard intermediate state of this combinator,
    /// so care should be taken to avoid losing resources when this is called.
    pub fn into_inner(self) -> (Si1, Si2) {
        (self.sink1, self.sink2)
    }
}

impl<Si1, Si2> Debug for Fanout<Si1, Si2>
where
    Si1: Debug,
    Si2: Debug,
{
    // Renders as `Fanout { sink1: .., sink2: .. }`, delegating to each sink's own `Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let Self { sink1, sink2 } = self;
        let mut builder = f.debug_struct("Fanout");
        builder.field("sink1", sink1);
        builder.field("sink2", sink2);
        builder.finish()
    }
}

impl<Si1, Si2, Item> Sink<Item> for Fanout<Si1, Si2>
where
    Si1: Sink<Item>,
    Si2: Sink<Item, SinkError = Si1::SinkError>,
    Item: Clone,
{
    type SinkError = Si1::SinkError;

    // Ready only when both sinks are ready. Both sinks are polled with `cx` on
    // every call (a `Pending` from the first does not skip the second); an error
    // from `sink1` is returned before `sink2` is polled.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::SinkError>> {
        let first = self.as_mut().sink1().poll_ready(cx)?;
        let second = self.as_mut().sink2().poll_ready(cx)?;
        match (first, second) {
            (Poll::Ready(()), Poll::Ready(())) => Poll::Ready(Ok(())),
            _ => Poll::Pending,
        }
    }

    // `sink1` receives a clone of the item and `sink2` takes the original;
    // if `sink1` rejects it, `sink2` never sees the item.
    fn start_send(
        mut self: Pin<&mut Self>,
        item: Item,
    ) -> Result<(), Self::SinkError> {
        let copy = item.clone();
        self.as_mut().sink1().start_send(copy)?;
        self.as_mut().sink2().start_send(item)
    }

    // Flushed only once both sinks report completion; both are polled each call.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::SinkError>> {
        let first = self.as_mut().sink1().poll_flush(cx)?;
        let second = self.as_mut().sink2().poll_flush(cx)?;
        match (first, second) {
            (Poll::Ready(()), Poll::Ready(())) => Poll::Ready(Ok(())),
            _ => Poll::Pending,
        }
    }

    // Closed only once both sinks report completion; both are polled each call.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::SinkError>> {
        let first = self.as_mut().sink1().poll_close(cx)?;
        let second = self.as_mut().sink2().poll_close(cx)?;
        match (first, second) {
            (Poll::Ready(()), Poll::Ready(())) => Poll::Ready(Ok(())),
            _ => Poll::Pending,
        }
    }
}

#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
    use crate::future::join3;
    use crate::sink::SinkExt;
    use crate::stream::{self, StreamExt};
    use futures_channel::mpsc;
    use futures_executor::block_on;
    use std::iter::Iterator;
    use std::vec::Vec;

    /// Every item from the source stream must reach both downstream channels,
    /// in order, even though the two channels have different buffer sizes.
    #[test]
    fn it_works() {
        let (left_tx, left_rx) = mpsc::channel(1);
        let (right_tx, right_rx) = mpsc::channel(2);
        let fanout = left_tx.fanout(right_tx).sink_map_err(|_| ());

        let forward = stream::iter((0..10).map(Ok)).forward(fanout);
        let gather_left = left_rx.collect::<Vec<_>>();
        let gather_right = right_rx.collect::<Vec<_>>();

        let (_, left, right) = block_on(join3(forward, gather_left, gather_right));

        let expected = (0..10).collect::<Vec<_>>();
        assert_eq!(left, expected);
        assert_eq!(right, expected);
    }
}