"""Async I/O backend support utilities."""

import logging
import socket
import threading
import time
from collections import deque
from queue import Empty
from time import sleep
from weakref import WeakKeyDictionary

from kombu.utils.compat import detect_environment

from celery import states
from celery.exceptions import TimeoutError
from celery.utils.threads import THREAD_TIMEOUT_MAX

# Message of the generic exception raised by
# greenletDrainer._ensure_not_shut_down() when the drainer has shut down
# without recording an exception of its own.
E_CELERY_RESTART_REQUIRED = "Celery must be restarted because a shutdown signal was detected."

# Names exported by ``from <this module> import *``.
__all__ = (
    'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
    'register_drainer',
)


class EventletAdaptedEvent:
    """
    An adapted eventlet event, designed to match the API of `threading.Event` and
    `gevent.event.Event`.

    Wraps an ``eventlet.Event`` (stored as ``evt``) and exposes ``is_set``,
    ``set`` and ``wait`` with ``threading.Event`` semantics: ``set`` may be
    called repeatedly and ``wait`` reports whether the flag is set.
    """

    def __init__(self):
        # Imported lazily so eventlet is only required when this adapter is
        # actually used.
        import eventlet
        self.evt = eventlet.Event()

    def is_set(self):
        """Return True once the event has been set."""
        return self.evt.ready()

    def set(self):
        """Set the event, waking every waiter.

        Idempotent like ``threading.Event.set``: eventlet's ``Event.send``
        refuses to fire an already-triggered event, so it is only called
        while the event is still unset.
        """
        if not self.evt.ready():
            self.evt.send()

    def wait(self, timeout=None):
        """Block until the event is set or *timeout* seconds elapse.

        Returns the flag (True if set, False on timeout), matching
        ``threading.Event.wait``; eventlet's own ``wait`` returns the sent
        value, which is ``None`` in both cases here.
        """
        self.evt.wait(timeout)
        return self.evt.ready()


# Registry of drainer classes keyed by environment name; the result consumer
# picks its drainer with ``drainers[detect_environment()]``.
drainers = {}


def register_drainer(name):
    """Build a class decorator that registers a drainer under *name*.

    The decorated class is stored in ``drainers[name]`` and returned
    unchanged.
    """
    def _register(drainer_cls):
        drainers[name] = drainer_cls
        return drainer_cls

    return _register


@register_drainer('default')
class Drainer:
    """Result draining service that drains in the calling thread.

    ``start`` and ``stop`` do nothing here; green-thread subclasses override
    them to manage a background drain loop.
    """

    def __init__(self, result_consumer):
        self.result_consumer = result_consumer

    def start(self):
        """Start draining (nothing to do for the default drainer)."""

    def stop(self):
        """Stop draining (nothing to do for the default drainer)."""

    def drain_events_until(self, p, timeout=None, interval=1, on_interval=None, wait=None):
        """Drain events until ``p.ready`` is true, yielding after each wait.

        Each wait is capped at *interval* seconds. *on_interval*, if given, is
        called after every wait. When *wait* is falsy the result consumer's
        ``drain_events`` is used.

        :raises socket.timeout: once *timeout* seconds have elapsed; a falsy
            *timeout* disables this limit.
        """
        if not wait:
            wait = self.result_consumer.drain_events
        started_at = time.monotonic()

        while True:
            # A single wait() only covers *interval* seconds, so the overall
            # deadline is re-checked on every pass.
            if timeout and time.monotonic() - started_at >= timeout:
                raise socket.timeout()
            try:
                yield self.wait_for(p, wait, timeout=interval)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # the awaited result has arrived
                return

    def wait_for(self, p, wait, timeout=None):
        """Perform one blocking *wait* call of at most *timeout* seconds."""
        wait(timeout=timeout)

    def _event(self):
        """Create the event primitive used by this drainer."""
        return threading.Event()


class greenletDrainer(Drainer):
    """Drainer that runs ``drain_events`` in a background green thread.

    Subclasses provide ``spawn`` (start a green thread running a callable
    and return its handle) and ``_event`` (create an Event-like object with
    ``set``/``is_set``/``wait``).
    """

    spawn = None
    _exc = None  # exception that terminated run(), if any
    _g = None  # handle returned by spawn() for the run() loop
    # Event set (and then replaced by a fresh one) after every
    # drain_events() iteration, so each waiter wakes once per drain.
    _drain_complete_event = None

    def _send_drain_complete_event(self):
        # Wake everyone waiting on the current event, then install a new,
        # unset event for the next round.
        self._drain_complete_event.set()
        self._drain_complete_event = self._event()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._started = self._event()
        self._stopped = self._event()
        self._shutdown = self._event()
        self._drain_complete_event = self._event()

    def run(self):
        """Drain loop executed in the spawned green thread.

        Drains in 1-second slices until ``stop()`` sets the stop flag. Any
        other exception is recorded in ``_exc`` and re-raised. On exit,
        waiters are released and the shutdown event is set.
        """
        self._started.set()

        try:
            while not self._stopped.is_set():
                try:
                    self.result_consumer.drain_events(timeout=1)
                    self._send_drain_complete_event()
                except socket.timeout:
                    pass
        except Exception as e:
            self._exc = e
            raise
        finally:
            self._send_drain_complete_event()
            try:
                # _exc is assigned before this point; _ensure_not_shut_down
                # relies on that ordering.
                self._shutdown.set()
            except RuntimeError as e:
                # Use a named logger: the module-level logging.error() helper
                # would call basicConfig() on an unconfigured root logger.
                logging.getLogger(__name__).error(
                    "Failed to set shutdown event: %s", e)

    def start(self):
        """Spawn the drain loop once and wait until it has started.

        :raises Exception: if the drainer has already shut down (see
            ``_ensure_not_shut_down``).
        """
        self._ensure_not_shut_down()

        if not self._started.is_set():
            self._g = self.spawn(self.run)
            self._started.wait()

    def stop(self):
        """Ask the drain loop to exit; wait up to THREAD_TIMEOUT_MAX for it."""
        self._stopped.set()
        self._shutdown.wait(THREAD_TIMEOUT_MAX)

    def wait_for(self, p, wait, timeout=None):
        """Wait up to *timeout* for the next completed drain unless ``p.ready``.

        *wait* is unused: draining happens in the background green thread.
        """
        self.start()
        if not p.ready:
            self._drain_complete_event.wait(timeout=timeout)

            self._ensure_not_shut_down()

    def _ensure_not_shut_down(self):
        """Currently used to ensure the drainer has not run to completion.

        Raises if the shutdown event has been signaled (either due to an exception
        or stop() being called).

        The _shutdown event acts as synchronization to ensure _exc is properly
        set before it is read from, avoiding need for locks.
        """
        if self._shutdown.is_set():
            if self._exc is not None:
                raise self._exc
            else:
                raise Exception(E_CELERY_RESTART_REQUIRED)


@register_drainer('eventlet')
class eventletDrainer(greenletDrainer):
    """Drainer whose drain loop runs in an eventlet green thread."""

    def spawn(self, func):
        """Start *func* in a green thread and yield once so it can begin."""
        import eventlet
        green_thread = eventlet.spawn(func)
        eventlet.sleep(0)
        return green_thread

    def _event(self):
        """Create an eventlet event exposing the threading.Event API."""
        return EventletAdaptedEvent()


@register_drainer('gevent')
class geventDrainer(greenletDrainer):
    """Drainer whose drain loop runs in a gevent greenlet."""

    def spawn(self, func):
        """Start *func* in a greenlet and yield once so it can begin."""
        from gevent import sleep as gevent_sleep
        from gevent import spawn as gevent_spawn
        greenlet = gevent_spawn(func)
        gevent_sleep(0)
        return greenlet

    def _event(self):
        """Create a ``gevent.event.Event``."""
        from gevent import event as gevent_event
        return gevent_event.Event()


class AsyncBackendMixin:
    """Mixin for backends that enables the async API.

    The host class must provide ``result_consumer``, ``_pending_results``
    (two mappings unpacked as ``concrete, weak``), ``_pending_messages`` and
    ``_ensure_not_eager``; none of these are defined here.
    """

    def _collect_into(self, result, bucket):
        """Ask the result consumer to append *result* to *bucket* when ready."""
        self.result_consumer.buckets[result] = bucket

    @staticmethod
    def _native_value(node):
        """Return the value ``iter_native`` yields next to ``node.id``.

        Nodes without a ``_cache`` attribute contribute ``node.children``.
        """
        if hasattr(node, '_cache'):
            return node._cache
        return node.children

    def iter_native(self, result, no_ack=True, **kwargs):
        """Yield ``(id, value)`` for each node of ``result.results`` as it is ready.

        *value* is the node's cached meta, or ``node.children`` for nodes
        without a ``_cache`` attribute. Extra keyword arguments are passed to
        ``_wait_for_pending``. Yields nothing when ``result.results`` is empty.
        """
        self._ensure_not_eager()

        results = result.results
        if not results:
            # A plain return ends the generator; raising StopIteration inside
            # a generator is converted to RuntimeError (PEP 479).
            return

        # Nodes that are already resolved (or have no cache to fill) go
        # straight into the bucket; the result consumer delivers the rest
        # into it as their state messages arrive.
        bucket = deque()
        for node in results:
            if not hasattr(node, '_cache') or node._cache:
                bucket.append(node)
            else:
                self._collect_into(node, bucket)

        for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
            while bucket:
                node = bucket.popleft()
                yield node.id, self._native_value(node)
        # Flush anything delivered after the final wait iteration, using the
        # same value rule as the loop above.
        while bucket:
            node = bucket.popleft()
            yield node.id, self._native_value(node)

    def add_pending_result(self, result, weak=False, start_drainer=True):
        """Track *result* until its state arrives; return *result*.

        If a message for it was already buffered, it is applied immediately
        instead of registering the result as pending.
        """
        if start_drainer:
            self.result_consumer.drainer.start()
        try:
            self._maybe_resolve_from_buffer(result)
        except Empty:
            self._add_pending_result(result.id, result, weak=weak)
        return result

    def _maybe_resolve_from_buffer(self, result):
        # ``take`` raises Empty when no message is buffered for this id.
        result._maybe_set_cache(self._pending_messages.take(result.id))

    def _add_pending_result(self, task_id, result, weak=False):
        """Store *result* under *task_id* and start consuming it, once."""
        concrete, weak_ = self._pending_results
        if task_id not in weak_ and task_id not in concrete:
            (weak_ if weak else concrete)[task_id] = result
            self.result_consumer.consume_from(task_id)

    def add_pending_results(self, results, weak=False):
        """Start the drainer once, then add every result in *results*."""
        self.result_consumer.drainer.start()
        return [self.add_pending_result(result, weak=weak, start_drainer=False)
                for result in results]

    def remove_pending_result(self, result):
        """Stop tracking *result* and cancel its consumer; return *result*."""
        self._remove_pending_result(result.id)
        self.on_result_fulfilled(result)
        return result

    def _remove_pending_result(self, task_id):
        for mapping in self._pending_results:
            mapping.pop(task_id, None)

    def on_result_fulfilled(self, result):
        self.result_consumer.cancel_for(result.id)

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        """Block until *result* is ready, then return ``result.maybe_throw(...)``."""
        self._ensure_not_eager()
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        # Delegate to the result consumer's generator.
        return self.result_consumer._wait_for_pending(
            result, timeout=timeout,
            on_interval=on_interval, on_message=on_message,
            **kwargs
        )

    @property
    def is_async(self):
        return True


class BaseResultConsumer:
    """Consume result messages and route them to the matching pending results.

    Subclasses implement the transport hooks ``start``, ``drain_events``,
    ``consume_from`` and ``cancel_for``.
    """

    def __init__(self, backend, app, accept,
                 pending_results, pending_messages):
        self.backend, self.app, self.accept = backend, app, accept
        self._pending_results = pending_results
        self._pending_messages = pending_messages
        self.on_message = None
        self.buckets = WeakKeyDictionary()
        # Pick the drainer implementation registered for the current
        # environment name.
        drainer_cls = drainers[detect_environment()]
        self.drainer = drainer_cls(self)

    def start(self, initial_task_id, **kwargs):
        raise NotImplementedError

    def stop(self):
        """Hook for subclasses; the base implementation does nothing."""

    def drain_events(self, timeout=None):
        raise NotImplementedError

    def consume_from(self, task_id):
        raise NotImplementedError

    def cancel_for(self, task_id):
        raise NotImplementedError

    def _after_fork(self):
        # Empty the existing mapping in place, then start the child off with
        # a brand-new one and no message callback.
        stale_buckets = self.buckets
        stale_buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

    def on_after_fork(self):
        """Hook for subclasses, called at the end of ``_after_fork``."""

    def drain_events_until(self, p, timeout=None, on_interval=None):
        """Delegate to the drainer's ``drain_events_until`` generator."""
        drainer = self.drainer
        return drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        """Yield once per drain iteration until ``result.on_ready`` is ready.

        *on_message* is installed for the duration of the wait and the
        previous callback is restored afterwards.

        :raises TimeoutError: when draining times out.
        """
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        saved_on_message = self.on_message
        self.on_message = on_message
        try:
            waiter = self.drain_events_until(
                result.on_ready, timeout=timeout, on_interval=on_interval)
            for _ in waiter:
                yield
                sleep(0)
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = saved_on_message

    def on_wait_for_pending(self, result, timeout=None, **kwargs):
        """Hook called before waiting starts; does nothing by default."""

    def on_out_of_band_result(self, message):
        self.on_state_change(message.payload, message)

    def _get_pending_result(self, task_id):
        """Return the pending result for *task_id* from the first mapping holding it.

        :raises KeyError: if no mapping contains *task_id*.
        """
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
            except KeyError:
                continue
        raise KeyError(task_id)

    def on_state_change(self, meta, message):
        """Handle a state message: notify, then resolve or buffer ready states."""
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            task_id = meta['task_id']
            try:
                result = self._get_pending_result(task_id)
            except KeyError:
                # The result is not registered yet: keep the message so it
                # can be applied when the result is added later.
                self._pending_messages.put(task_id, meta)
            else:
                result._maybe_set_cache(meta)
                try:
                    # A fulfilled result no longer needs its bucket entry.
                    waiting_bucket = self.buckets.pop(result)
                except KeyError:
                    pass
                else:
                    waiting_bucket.append(result)
        sleep(0)
