1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
use crate::future::FutureExt;
use super::SpawnError;
use futures_channel::oneshot::{self, Sender, Receiver};
use futures_core::future::Future;
use futures_core::task::{self, Poll, Spawn, SpawnObjError};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::marker::Unpin;
use std::mem::PinMut;
use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;

/// The join handle returned by
/// [`spawn_with_handle`](crate::task::SpawnExt::spawn_with_handle).
///
/// Polling the handle yields the spawned future's output. If the spawned
/// future panicked, that panic is resumed in the task that polls this handle.
///
/// The spawned future is driven by the executor whether or not this handle is
/// polled. Dropping the handle cancels it: the next time the spawned task is
/// polled it stops without polling the future any further. Use
/// [`JoinHandle::forget`] to drop the handle while letting the future run to
/// completion.
#[must_use = "dropping a JoinHandle cancels the underlying future"]
#[derive(Debug)]
pub struct JoinHandle<T> {
    /// Receives the spawned future's output, or its panic payload if it
    /// panicked.
    rx: Receiver<thread::Result<T>>,
    /// Shared with the spawned task. `forget` sets it so that the task keeps
    /// running after this handle (and therefore `rx`) has been dropped.
    keep_running: Arc<AtomicBool>,
}

impl<T> JoinHandle<T> {
    /// Drops this handle *without* canceling the underlying future.
    ///
    /// This method can be used if you want to drop the handle, but let the
    /// execution continue.
    pub fn forget(self) {
        self.keep_running.store(true, Ordering::SeqCst);
    }
}

impl<T: Send + 'static> Future for JoinHandle<T> {
    type Output = T;

    /// Waits for the spawned future's result.
    ///
    /// If the spawned future panicked, its captured panic payload is re-raised
    /// here with `resume_unwind`. If the sending half was dropped without
    /// delivering anything (for example, the spawned task was dropped before
    /// finishing), the channel's receive error is boxed and re-raised as a
    /// panic payload as well.
    fn poll(mut self: PinMut<Self>, cx: &mut task::Context) -> Poll<T> {
        let received = match self.rx.poll_unpin(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(received) => received,
        };

        match received {
            Ok(Ok(value)) => Poll::Ready(value),
            Ok(Err(payload)) => panic::resume_unwind(payload),
            Err(canceled) => panic::resume_unwind(Box::new(canceled)),
        }
    }
}

/// The future actually handed to the executor by `spawn_with_handle`.
///
/// It drives the user's future, sends the output through `tx` when that
/// future completes, and stops early if the matching `JoinHandle` was dropped
/// without calling `forget`.
struct Wrapped<Fut: Future> {
    // Sending half of the result channel. It stays `Some` until `poll` takes
    // it out to send the inner future's output.
    tx: Option<Sender<Fut::Output>>,
    // Shared with the `JoinHandle`. `true` means "keep running even though the
    // handle has been dropped".
    keep_running: Arc<AtomicBool>,
    // The wrapped future. This is the only field projected as pinned below.
    future: Fut,
}

// `future` is the only pinned field, so `Wrapped` may be `Unpin` whenever the
// inner future is.
// NOTE(review): keeping this impl conditional on `Fut: Unpin` (rather than
// unconditional) is what the `unsafe_pinned!` projection below relies on for
// soundness. It also assumes no `Drop` impl for `Wrapped` moves `future`;
// none is visible here. Confirm before changing either.
impl<Fut: Future + Unpin> Unpin for Wrapped<Fut> {}

// Pin-projection accessors used by `poll`.
impl<Fut: Future> Wrapped<Fut> {
    // `self.future()` yields a pinned reference to the inner future.
    unsafe_pinned!(future: Fut);
    // `self.tx()` yields a plain `&mut` to the sender (not structurally pinned).
    unsafe_unpinned!(tx: Option<Sender<Fut::Output>>);
    // `self.keep_running()` yields a plain `&mut` to the shared flag.
    unsafe_unpinned!(keep_running: Arc<AtomicBool>);
}

impl<Fut: Future> Future for Wrapped<Fut> {
    type Output = ();

    /// Drives the inner future and forwards its output to the `JoinHandle`.
    ///
    /// Before touching the inner future, this checks whether the receiving
    /// half has been dropped. If it has, and `keep_running` is not set (it is
    /// set by `JoinHandle::forget`), the task finishes immediately without
    /// polling the inner future again.
    ///
    /// # Panics
    ///
    /// Panics if polled again after the output has been forwarded, because
    /// the sender has already been taken at that point.
    fn poll(mut self: PinMut<Self>, cx: &mut task::Context) -> Poll<()> {
        let receiver_dropped = match self.tx().as_mut().unwrap().poll_cancel(cx) {
            Poll::Ready(_) => true,
            Poll::Pending => false,
        };
        if receiver_dropped && !self.keep_running().load(Ordering::SeqCst) {
            // Nobody is waiting for the result any more, so stop here.
            return Poll::Ready(());
        }

        let output = match self.future().poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(output) => output,
        };

        // A failed send only means the receiving half is gone (for example,
        // after `forget`). Discarding the output is the intended behavior.
        let sender = self.tx().take().unwrap();
        let _ = sender.send(output);
        Poll::Ready(())
    }
}

/// Spawns `future` onto `executor` and returns a [`JoinHandle`] for its output.
///
/// The future is wrapped so that any panic it raises is caught and delivered
/// to the handle. The wrapper also lets the task stop early once the handle
/// is dropped, unless [`JoinHandle::forget`] was called first.
///
/// # Errors
///
/// Returns a [`SpawnError`] carrying the executor's error kind if
/// `spawn_obj` rejects the task.
pub(super) fn spawn_with_handle<Sp, Fut>(
    executor: &mut Sp,
    future: Fut,
) -> Result<JoinHandle<Fut::Output>, SpawnError>
where
    Sp: Spawn + ?Sized,
    Fut: Future + Send + 'static,
    Fut::Output: Send,
{
    let (tx, rx) = oneshot::channel();
    let keep_running = Arc::new(AtomicBool::new(false));

    // `AssertUnwindSafe` is required because the `Send + 'static` bounds are
    // effectively standing in for `UnwindSafe` here, and the standard library
    // has no way to express that relationship directly.
    let task = Wrapped {
        future: AssertUnwindSafe(future).catch_unwind(),
        tx: Some(tx),
        keep_running: Arc::clone(&keep_running),
    };

    if let Err(SpawnObjError { kind, .. }) = executor.spawn_obj(Box::new(task).into()) {
        return Err(SpawnError { kind });
    }
    Ok(JoinHandle { rx, keep_running })
}