# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import abc
import asyncio
import contextlib
import subprocess
import time as _time
import typing as tp
import uuid
import warnings
from pathlib import Path

from typing_extensions import TypedDict

from . import logger, utils

# R as in "Result": the type returned by Job.result() / Job.results().
# It only appears as an output type, so yes it's covariant.
# pylint: disable=typevar-name-incorrect-variance
R = tp.TypeVar("R", covariant=True)


class InfoWatcher:
    """Shared status poller: one instance is used by all jobs, and queries the cluster
    for all registered jobs at once (so as not to overload it). It also deals with errors
    of the status command.
    The cluster is called at 0s, 2s, 4s, 8s etc... at the beginning of jobs, then at least
    every delay_s (default: 60)

    Parameters
    ----------
    delay_s: int
        Maximum delay before each non-forced call to the cluster.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(self, delay_s: int = 60) -> None:
        self._delay_s = delay_s
        self._num_calls = 0
        # time of the last registration (or cache clearing), drives the polling back-off
        self._start_time = 0.0
        # -inf so that the very first non-cached request triggers an update
        self._last_status_check = float("-inf")
        self._output = b""  # for the record
        self._info_dict: tp.Dict[str, tp.Dict[str, str]] = {}
        self._registered: tp.Set[str] = set()
        self._finished: tp.Set[str] = set()

    def read_info(self, string: tp.Union[bytes, str]) -> tp.Dict[str, tp.Dict[str, str]]:
        """Converts the output of the status command into a {job_id: info} dict.
        Not implemented in this base class.
        """
        raise NotImplementedError

    def _make_command(self) -> tp.Optional[tp.List[str]]:
        """Builds the status command to run (None means there is nothing to call).
        Not implemented in this base class.
        """
        raise NotImplementedError

    def get_state(self, job_id: str, mode: str = "standard") -> str:
        """Returns the state of the given job.
        Not implemented in this base class.
        """
        raise NotImplementedError

    @property
    def num_calls(self) -> int:
        """Number of calls to sacct"""
        return self._num_calls

    def clear(self) -> None:
        """Clears cache.
        Registered jobs and the number of calls are kept as is.
        This should hopefully not be used. If you have to use it, please add a github issue.
        """
        self._output = b""
        self._info_dict = {}
        self._finished = set()
        self._last_status_check = float("-inf")
        self._start_time = _time.time()

    def get_info(self, job_id: str, mode: str = "standard") -> tp.Dict[str, str]:
        """Returns a dict containing info about the job (empty if nothing is known yet).
        The job is registered on the fly if it was not already.
        State of finished jobs are cached (use watcher.clear() to remove all cache)

        Parameters
        ----------
        job_id: str
            id of the job on the cluster
        mode: str
            one of "force" (forces a call), "standard" (calls regularly) or "cache" (does not call)

        Raises
        ------
        RuntimeError
            if job_id is None
        """
        if job_id is None:
            raise RuntimeError("Cannot call sacct without a slurm id")
        already_known = job_id in self._registered
        if not already_known:
            self.register_job(job_id)
        # check with a call to sacct/cinfo
        self.update_if_long_enough(mode)
        try:
            return self._info_dict[job_id]
        except KeyError:
            return {}

    def is_done(self, job_id: str, mode: str = "standard") -> bool:
        """Returns whether the job is finished, i.e. whether its state
        is not one of the known unfinished states.

        Parameters
        ----------
        job_id: str
            id of the job on the cluster
        mode: str
            one of "force" (forces a call), "standard" (calls regularly) or "cache" (does not call)
        """
        unfinished = ("READY", "PENDING", "RUNNING", "UNKNOWN", "REQUEUED", "COMPLETING", "PREEMPTED")
        current = self.get_state(job_id, mode=mode).upper()
        return current not in unfinished

    def update_if_long_enough(self, mode: str) -> None:
        """Updates if forced to, or if the delay is reached
        (Force-updates with less than 1ms delay are ignored)
        Also checks for finished jobs

        Parameters
        ----------
        mode: str
            one of "force" (forces a call), "standard" (calls regularly) or "cache" (does not call)
        """
        assert mode in ["standard", "force", "cache"]
        if mode == "cache":
            return
        since_last_check = _time.time() - self._last_status_check
        since_last_job = _time.time() - self._start_time
        # the following will call update at time 0s, 2s, 4, 8, 16, 32, 64, 124 (delta 60), 184 (delta 60) etc...
        # of last added job (for delay_s = 60)
        backoff = min(self._delay_s, max(2, since_last_job / 2))
        threshold = 0.001 if mode == "force" else backoff
        if since_last_check > threshold:
            self.update()

    def update(self) -> None:
        """Updates the info of all registered jobs with a call to sacct.
        A failing call is logged and bypassed, the previous information is then kept.
        """
        cmd = self._make_command()
        if cmd is None:
            return
        self._num_calls += 1
        try:
            logger.get_logger().debug(f"Call #{self.num_calls} - Command {' '.join(cmd)}")
            self._output = subprocess.check_output(cmd, shell=False)
        except Exception as e:
            logger.get_logger().warning(
                f"Call #{self.num_calls} - Bypassing sacct error {e}, status may be inaccurate."
            )
        else:
            self._info_dict.update(self.read_info(self._output))
        self._last_status_check = _time.time()
        # record newly finished jobs, using cached information only
        for job_id in self._registered - self._finished:
            if self.is_done(job_id, mode="cache"):
                self._finished.add(job_id)

    def register_job(self, job_id: str) -> None:
        """Register a job on the instance for shared update.
        This restarts the polling schedule so that the next request triggers a call.
        """
        assert isinstance(job_id, str)
        self._registered.add(job_id)
        self._last_status_check = float("-inf")
        self._start_time = _time.time()


# pylint: disable=too-many-public-methods
class Job(tp.Generic[R]):
    """Access to a cluster job information and result.

    Parameters
    ----------
    folder: Path/str
        A path to the submitted job file
    job_id: str
        the id of the cluster job
    tasks: Sequence[int]
        The ids of the tasks associated to this job.
        By default, the job has only one task (with id = 0).
        With several ids, the instance is a "meta-job" holding one sub-job per task.
    """

    # command used for cancellation, called as "<command> <job_id>"
    _cancel_command = "dummy"
    # time (in seconds) allowed for the result pickle to show up when reading the outcome
    _results_timeout_s = 15
    # defined at class level, hence shared between instances
    watcher = InfoWatcher()

    def __init__(self, folder: tp.Union[Path, str], job_id: str, tasks: tp.Sequence[int] = (0,)) -> None:
        self._job_id = job_id
        self._tasks = tuple(tasks)
        self._sub_jobs: tp.Sequence["Job[R]"] = []
        self._cancel_at_deletion = False
        # rely on the materialized tuple: "tasks" may be an iterable without len()
        if len(self._tasks) > 1:
            # This is a meta-Job
            self._sub_jobs = [self.__class__(folder=folder, job_id=job_id, tasks=(k,)) for k in self._tasks]
        self._paths = utils.JobPaths(folder, job_id=job_id, task_id=self.task_id)
        self._start_time = _time.time()
        self._last_status_check = self._start_time  # for the "done()" method
        # register for state updates with watcher
        self._register_in_watcher()

    def _register_in_watcher(self) -> None:
        """Registers the job id in the shared watcher (only when the first task id is 0)"""
        if not self._tasks[0]:  # only register for task=0
            self.watcher.register_job(self.job_id)

    @property
    def job_id(self) -> str:
        """Id of the job on the cluster"""
        return self._job_id

    @property
    def paths(self) -> utils.JobPaths:
        """Paths of the files associated to this job"""
        return self._paths

    @property
    def num_tasks(self) -> int:
        """Returns the number of tasks in the Job"""
        if not self._sub_jobs:
            return 1
        return len(self._sub_jobs)

    def submission(self) -> utils.DelayedSubmission:
        """Returns the submitted object, with attributes `function`, `args` and `kwargs`"""
        assert (
            self.paths.submitted_pickle.exists()
        ), f"Cannot find job submission pickle: {self.paths.submitted_pickle}"
        return utils.DelayedSubmission.load(self.paths.submitted_pickle)

    def cancel_at_deletion(self, value: bool = True) -> "Job[R]":
        """Sets whether the job deletion in the python environment triggers
        cancellation of the corresponding job in the cluster
        By default, jobs are not cancelled unless this method is called to turn the
        option on.

        Parameters
        ----------
        value: bool
            if True, the cluster job will be cancelled at the instance deletion, if False, it
            will not.

        Returns
        -------
        Job
            the current job (for chaining at submission for instance: "job = executor.submit(...).cancel_at_deletion()")
        """
        self._cancel_at_deletion = value
        return self

    def task(self, task_id: int) -> "Job[R]":
        """Returns a given sub-Job (task).

        Parameters
        ----------
        task_id
            The id of the task. Must verify 0 <= task_id < self.num_tasks

        Returns
        -------
        job
            The sub_job. You can call all Job methods on it (done, stdout, ...)
            If the job doesn't have sub jobs, return the job itself.

        Raises
        ------
        ValueError
            if the task id is out of range
        """
        if not 0 <= task_id < self.num_tasks:
            raise ValueError(f"task_id {task_id} must be between 0 and {self.num_tasks - 1}")

        if not self._sub_jobs:
            return self
        return self._sub_jobs[task_id]

    def cancel(self, check: bool = True) -> None:
        """Cancels the job, by calling the cancel command with the job id

        Parameters
        ----------
        check: bool
            whether to wait for completion and check that the command worked
        """
        (subprocess.check_call if check else subprocess.call)(
            [self._cancel_command, f"{self.job_id}"], shell=False
        )

    def result(self) -> R:
        """Waits for and outputs the result of the submitted function, for single-task jobs.

        Returns
        -------
        output
            the output of the submitted function.

        Raises
        ------
        AssertionError
            if the job has several tasks (use `results()` instead)
        Exception
            Any exception raised by the job
        """
        # check for misuse first, so as to fail right away instead of after waiting for the job
        assert not self._sub_jobs, "You should use `results()` if your job has subtasks."
        return self.results()[0]

    def results(self) -> tp.List[R]:
        """Waits for and outputs the result of the submitted function

        Returns
        -------
        output
            the output of the submitted function.
            If the job has several tasks, it will return the output of every tasks in a List

        Raises
        ------
        Exception
            Any exception raised by the job
        """
        self.wait()

        if self._sub_jobs:
            return [tp.cast(R, sub_job.result()) for sub_job in self._sub_jobs]

        outcome, result = self._get_outcome_and_result()
        if outcome == "error":
            job_exception = self.exception()
            if job_exception is None:
                raise RuntimeError("Unknown job exception")
            raise job_exception  # pylint: disable=raising-bad-type
        return [result]

    def exception(self) -> tp.Optional[tp.Union[utils.UncompletedJobError, utils.FailedJobError]]:
        """Waits for completion and returns (not raise) the
        exception containing the error log of the job

        Returns
        -------
        Exception/None
            the exception if any was raised during the job.
            If the job has several tasks, it returns the exception of the task with
            smallest id that failed.
            An UncompletedJobError is returned (not raised) if no result could be read.
        """
        self.wait()

        if self._sub_jobs:
            all_exceptions = [sub_job.exception() for sub_job in self._sub_jobs]
            # unexpected pylint issue on correct code:
            exceptions = [
                e for e in all_exceptions if e is not None  # pylint: disable=used-before-assignment
            ]
            if not exceptions:
                return None
            return exceptions[0]

        try:
            outcome, trace = self._get_outcome_and_result()
        except utils.UncompletedJobError as e:
            return e
        if outcome == "error":
            # note the trailing space before "or at paths", so that the sentence reads properly
            return utils.FailedJobError(
                f"Job (task={self.task_id}) failed during processing with trace:\n"
                f"----------------------\n{trace}\n"
                "----------------------\n"
                f"You can check full logs with 'job.stderr({self.task_id})' and 'job.stdout({self.task_id})' "
                f"or at paths:\n  - {self.paths.stderr}\n  - {self.paths.stdout}"
            )
        return None

    def _get_outcome_and_result(self) -> tp.Tuple[str, tp.Any]:
        """Getter for the output of the submitted function.

        Returns
        -------
        outcome
            the outcome of the job: either "error" or "success"
        result
            the output of the submitted function

        Raises
        ------
        UncompletedJobError
            if the job is not finished or failed outside of the job (from slurm)
        """
        assert not self._sub_jobs, "This should not be called for a meta-job"

        p = self.paths.folder
        timeout = self._results_timeout_s
        try:
            # trigger cache update: https://stackoverflow.com/questions/3112546/os-path-exists-lies/3112717
            p.chmod(p.stat().st_mode)
        except PermissionError:
            # chmod requires file ownership and might fail.
            # Increase the timeout since we can't force cache refresh.
            timeout *= 2
        # if filesystem is slow, we need to wait a bit for result_pickle.
        start_wait = _time.time()
        while not self.paths.result_pickle.exists() and _time.time() - start_wait < timeout:
            _time.sleep(1)
        if not self.paths.result_pickle.exists():
            # no result: build an error message with as much log information as possible
            message = [
                f"Job {self.job_id} (task: {self.task_id}) with path {self.paths.result_pickle}",
                f"has not produced any output (state: {self.state})",
            ]
            log = self.stderr()
            if log:
                message.extend(["Error stream produced:", "-" * 40, log])
            elif self.paths.stdout.exists():
                # only report the last 40 lines of stdout
                log = subprocess.check_output(["tail", "-40", str(self.paths.stdout)], encoding="utf-8")
                message.extend(
                    [f"No error stream produced. Look at stdout: {self.paths.stdout}", "-" * 40, log]
                )
            else:
                message.append(f"No output/error stream produced ! Check: {self.paths.stdout}")
            raise utils.UncompletedJobError("\n".join(message))
        try:
            output: tp.Tuple[str, tp.Any] = utils.pickle_load(self.paths.result_pickle)
        except EOFError:
            # retry once after a short delay
            warnings.warn(f"EOFError on file {self.paths.result_pickle}, trying again in 2s")  # will it work?
            _time.sleep(2)
            output = utils.pickle_load(self.paths.result_pickle)
        return output

    def wait(self) -> None:
        """Waits while no result file is found and the job is not reported
        as finished (polls `done()` every second).
        The state is checked from slurm at least every min and the result path
        every second.
        """
        while not self.done():
            _time.sleep(1)

    def done(self, force_check: bool = False) -> bool:
        """Checks whether the job is finished.
        This is done by checking if the result file is present,
        or checking the job state regularly (at least every minute)
        If the job has several tasks, the job is done once all tasks are done.

        Parameters
        ----------
        force_check: bool
            Forces the slurm state update

        Returns
        -------
        bool
            whether the job is finished or not

        Note
        ----
        This function is not foolproof, and may say that the job is not terminated even
        if it is when the job failed (no result file, but job not running) because
        we avoid calling sacct/cinfo everytime done is called
        """
        # TODO: keep state info once job is finished?
        if self._sub_jobs:
            # forward the flag, so that a forced check is not silently downgraded for meta-jobs
            return all(sub_job.done(force_check=force_check) for sub_job in self._sub_jobs)
        p = self.paths.folder
        try:
            # trigger cache update: https://stackoverflow.com/questions/3112546/os-path-exists-lies/3112717
            p.chmod(p.stat().st_mode)
        except OSError:
            # best effort only
            pass
        if self.paths.result_pickle.exists():
            return True
        # check with a call to sacct/cinfo
        if self.watcher.is_done(self.job_id, mode="force" if force_check else "standard"):
            return True
        return False

    @property
    def task_id(self) -> tp.Optional[int]:
        """Id of the task, or None for a meta-job (several tasks)"""
        return None if len(self._tasks) > 1 else self._tasks[0]

    @property
    def state(self) -> str:
        """State of the job (does not force an update)"""
        return self.watcher.get_state(self.job_id, mode="standard")

    def get_info(self, mode: str = "force") -> tp.Dict[str, str]:
        """Returns informations about the job as a dict (sacct call)

        Parameters
        ----------
        mode: str
            one of "force" (forces a call), "standard" (calls regularly) or "cache" (does not call)
        """
        return self.watcher.get_info(self.job_id, mode=mode)

    def _get_logs_string(self, name: str) -> tp.Optional[str]:
        """Returns a string with the content of the log file
        or None if the file does not exist yet

        Parameter
        ---------
        name: str
            either "stdout" or "stderr"

        Raises
        ------
        ValueError
            if the name is neither "stdout" nor "stderr"
        """
        paths = {"stdout": self.paths.stdout, "stderr": self.paths.stderr}
        if name not in paths:
            raise ValueError(f'Unknown "{name}", available are {list(paths.keys())}')
        if not paths[name].exists():
            return None
        with paths[name].open("r") as f:
            string: str = f.read()
        return string

    def stdout(self) -> tp.Optional[str]:
        """Returns a string with the content of the print log file
        or None if the file does not exist yet
        If the job has several tasks, the available logs are joined with a line break.
        """
        if self._sub_jobs:
            stdout_ = [sub_job.stdout() for sub_job in self._sub_jobs]
            stdout_not_none = [s for s in stdout_ if s is not None]
            if not stdout_not_none:
                return None
            return "\n".join(stdout_not_none)

        return self._get_logs_string("stdout")

    def stderr(self) -> tp.Optional[str]:
        """Returns a string with the content of the error log file
        or None if the file does not exist yet
        If the job has several tasks, the available logs are joined with a line break.
        """
        if self._sub_jobs:
            stderr_ = [sub_job.stderr() for sub_job in self._sub_jobs]
            stderr_not_none: tp.List[str] = [s for s in stderr_ if s is not None]
            if not stderr_not_none:
                return None
            return "\n".join(stderr_not_none)
        return self._get_logs_string("stderr")

    def awaitable(self) -> "AsyncJobProxy[R]":
        """Returns a proxy object that provides asyncio methods
        for this Job.
        """
        return AsyncJobProxy(self)

    def __repr__(self) -> str:
        # the representation must never fail: fall back to "UNKNOWN" if the state cannot be read
        state = "UNKNOWN"
        try:
            state = self.state
        except Exception as e:
            logger.get_logger().warning(f"Bypassing state error:\n{e}")
        return f'{self.__class__.__name__}<job_id={self.job_id}, task_id={self.task_id}, state="{state}">'

    def __del__(self) -> None:
        # cancel the cluster job if requested, unless the cached state reports it as finished
        if self._cancel_at_deletion:
            if not self.watcher.is_done(self.job_id, mode="cache"):
                self.cancel(check=False)

    def __getstate__(self) -> tp.Dict[str, tp.Any]:
        return self.__dict__  # for pickling (see __setstate__)

    def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
        """Make sure jobs are registered when loaded from a pickle"""
        self.__dict__.update(state)
        self._register_in_watcher()


class DelayedJob(Job[R]):
    """
    Placeholder for a Job which has been queued for submission by an executor,
    but hasn't been scheduled yet.
    Typically obtained by calling `ex.submit` within a `ex.batch()` context

    By default, reading an attribute of such a job fails.
    But if you passed `ex.batch(allow_implicit_submissions=True)` then
    reading an attribute forces the submission of the batch,
    and the instance becomes a real job.
    """

    def __init__(self, ex: "Executor"):
        # pylint: disable = super-init-not-called
        self._submitit_executor = ex

    def __getattr__(self, name: str) -> tp.Any:
        """Called for attributes missing from the instance: either submits the batch
        and reads the attribute on the promoted job, or raises AttributeError
        when implicit submissions are not allowed.
        """
        if name == "_cancel_at_deletion":
            # read by __del__, which must not trigger a submission
            return False

        executor = self.__dict__["_submitit_executor"]
        if executor._allow_implicit_submissions:
            # submitting the batch fills the instance attributes (through _promote)
            executor._submit_delayed_batch()
            # Ensure that _promote did get called, otherwise getattr will trigger a stack overflow
            assert self.__class__ != DelayedJob, f"Executor {executor} didn't properly submitted {self} !"
            return getattr(self, name)
        # we are within executor.batch(), without `executor.batch(allow_implicit_submissions=True)`
        raise AttributeError("Accesssing job attributes is forbidden within 'with executor.batch()' context")

    def _promote(self, new_job: Job[tp.Any]) -> None:
        """Turns this empty shell into the provided job, the pickle way
        (attributes and class are copied from new_job)
        """
        attributes = self.__dict__
        attributes.pop("_submitit_executor", None)
        attributes.update(new_job.__dict__)
        # pylint: disable=attribute-defined-outside-init
        self.__class__ = new_job.__class__  # type: ignore

    def __repr__(self) -> str:
        # default object representation, which does not read any job attribute
        return object.__repr__(self)


class AsyncJobProxy(tp.Generic[R]):
    """Proxy providing asyncio versions of the waiting methods of a Job.

    Parameters
    ----------
    job: Job
        the job to wait for
    """

    def __init__(self, job: Job[R]):
        self.job = job

    async def wait(self, poll_interval: tp.Union[int, float] = 1) -> None:
        """same as wait() but with asyncio sleep.

        Parameters
        ----------
        poll_interval: int or float
            how often to check if the job is done, in seconds
        """
        while not self.job.done():
            await asyncio.sleep(poll_interval)

    async def result(self, poll_interval: tp.Union[int, float] = 1) -> R:
        """asyncio version of the result() method.

        Waits asynchronously for the result to be available by polling the job's done() method.

        Parameters
        ----------
        poll_interval: int or float
            how often to check if the result is available, in seconds
        """
        await self.wait(poll_interval)
        return self.job.result()

    async def results(self, poll_interval: tp.Union[int, float] = 1) -> tp.List[R]:
        """asyncio version of the results() method.

        Waits asynchronously for ALL the results to be available by polling the job's done() method.

        Parameters
        ----------
        poll_interval: int or float
            how often to check if the result is available, in seconds
        """
        await self.wait(poll_interval)
        # results are ready now
        return self.job.results()

    def results_as_completed(self, poll_interval: tp.Union[int, float] = 1) -> tp.Iterator[asyncio.Future]:
        """awaits for all tasks results concurrently. Note that the order of results is not guaranteed to match the order
        of the tasks anymore as the earliest task coming back might not be the first one you sent.

        Returns
        -------
        an iterable of Awaitables that can be awaited on to get the earliest result available of the remaining tasks.
        Exactly one awaitable is provided per task.

        Parameters
        ----------
        poll_interval: int or float
            how often to check if the result is available, in seconds

        (see https://docs.python.org/3/library/asyncio-task.html#asyncio.as_completed)
        """
        if self.job.num_tasks > 1:
            yield from asyncio.as_completed(
                [self.job.task(i).awaitable().result(poll_interval) for i in range(self.job.num_tasks)]
            )
            # all tasks have been covered: stop here, since calling result() on the
            # multi-task job itself is not allowed
            return

        # there is only one result anyway, let's just use async result
        yield asyncio.ensure_future(self.result(poll_interval))


# Message explaining that jobs cannot be interacted with inside a batch context.
# NOTE(review): not referenced in the visible part of this file - confirm it is still used.
_MSG = (
    "Interactions with jobs are not allowed within "
    '"with executor.batch()" context (submissions/creations only happens at exit time).'
)


class EquivalenceDict(TypedDict):
    """Gives the specific name of the params shared across all plugins.

    Each key is the generic name of a parameter, and each value is the name used
    for this parameter by a given plugin (see `Executor._convert_parameters`,
    which renames the keys of the parameters accordingly).
    """

    # Note that all values are typed as strings, even though the corresponding parameters
    # are integers: the values are parameter *names*, which allows for static typing of
    # the "_equivalence_dict" method implemented by plugins.
    # We could choose to put the proper types, but that wouldn't be enough to typecheck
    # the calls to `update_parameters` which uses kwargs.
    name: str
    timeout_min: str
    mem_gb: str
    nodes: str
    cpus_per_task: str
    gpus_per_node: str
    tasks_per_node: str


class Executor(abc.ABC):
    """Base job executor.

    Parameters
    ----------
    folder: Path/str
        folder for storing job submission/output and logs.
    parameters: dict (optional)
        initial submission parameters (empty if not provided)
    """

    job_class: tp.Type[Job[tp.Any]] = Job

    def __init__(self, folder: tp.Union[str, Path], parameters: tp.Optional[tp.Dict[str, tp.Any]] = None):
        self.folder = Path(folder).expanduser().absolute()
        self.parameters = {} if parameters is None else parameters
        # storage for the batch context manager, for batch submissions:
        # None outside of a batch context, a list of (delayed job, submission) within it
        self._delayed_batch: tp.Optional[tp.List[tp.Tuple[Job[tp.Any], utils.DelayedSubmission]]] = None
        self._allow_implicit_submissions = False

    @classmethod
    def name(cls) -> str:
        """Lower-cased name of the class, without the "Executor" suffix if any"""
        n = cls.__name__
        if n.endswith("Executor"):
            n = n[: -len("Executor")]
        return n.lower()

    @contextlib.contextmanager
    def batch(self, allow_implicit_submissions: bool = False) -> tp.Iterator[None]:
        """Creates a context within which all submissions are packed into a job array.
        By default the array submissions happens when leaving the context

        Parameter
        ---------
        allow_implicit_submissions: bool
            submits the current batch whenever a job attribute is accessed instead of raising an exception

        Example
        -------
        jobs = []
        with executor.batch():
            for k in range(12):
                jobs.append(executor.submit(add, k, 1))

        Raises
        ------
        RuntimeError
            if trying to nest batch contexts
        AttributeError
            if trying to access a job instance attribute while the batch is not exited, and
            intermediate submissions are not allowed.
        """
        # check for nesting before touching any state, so that a rejected nested call
        # does not overwrite the setting of the batch context which is already open
        if self._delayed_batch is not None:
            raise RuntimeError('Nesting "with executor.batch()" contexts is not allowed.')
        self._allow_implicit_submissions = allow_implicit_submissions
        self._delayed_batch = []
        try:
            yield None
        except Exception:
            logger.get_logger().error(
                'Caught error within "with executor.batch()" context, submissions are dropped.\n '
            )
            raise
        else:
            self._submit_delayed_batch()
        finally:
            self._delayed_batch = None

    def _submit_delayed_batch(self) -> None:
        """Submits all the pending submissions of the batch at once, and turns
        the corresponding delayed jobs into the actual jobs.
        Must be called within a batch context.
        """
        assert self._delayed_batch is not None
        if not self._delayed_batch:
            if not self._allow_implicit_submissions:
                warnings.warn(
                    'No submission happened during "with executor.batch()" context.', category=RuntimeWarning
                )
            return
        jobs, submissions = zip(*self._delayed_batch)
        # zip provides tuples, while a list is expected by the submission method
        new_jobs = self._internal_process_submissions(list(submissions))
        for j, new_j in zip(jobs, new_jobs):
            j._promote(new_j)
        # submitted: start over with an empty batch (the context may still be open)
        self._delayed_batch = []

    def submit(self, fn: tp.Callable[..., R], *args: tp.Any, **kwargs: tp.Any) -> Job[R]:
        """Submits a function call for computation.
        Within a batch context, the submission is delayed and a DelayedJob is returned.

        Parameters
        ----------
        fn: callable
            function to compute
        *args: any positional argument for the function
        **kwargs: any named argument for the function

        Returns
        -------
        Job
            the job corresponding to the submission

        Raises
        ------
        RuntimeError
            if the implementation returned an instance of the base Job class
        """
        ds = utils.DelayedSubmission(fn, *args, **kwargs)
        if self._delayed_batch is not None:
            job: Job[R] = DelayedJob(self)
            self._delayed_batch.append((job, ds))
        else:
            job = self._internal_process_submissions([ds])[0]
        if type(job) is Job:  # pylint: disable=unidiomatic-typecheck
            raise RuntimeError("Executors should never return a base Job class (implementation issue)")
        return job

    @abc.abstractmethod
    def _internal_process_submissions(
        self, delayed_submissions: tp.List[utils.DelayedSubmission]
    ) -> tp.List[Job[tp.Any]]:
        """Submits the provided submissions and returns the corresponding jobs.
        This must be implemented by subclasses.
        """
        ...

    def map_array(self, fn: tp.Callable[..., R], *iterable: tp.Iterable[tp.Any]) -> tp.List[Job[R]]:
        """A distributed equivalent of the map() built-in function

        Parameters
        ----------
        fn: callable
            function to compute
        *iterable: Iterable
            lists of arguments that are passed as arguments to fn.

        Returns
        -------
        List[Job]
            A list of Job instances (empty, with a warning, if there is nothing to submit).

        Example
        -------
        a = [1, 2, 3]
        b = [10, 20, 30]
        executor.map_array(add, a, b)
        # jobs will compute 1 + 10, 2 + 20, 3 + 30
        """
        submissions = [utils.DelayedSubmission(fn, *args) for args in zip(*iterable)]
        if not submissions:
            warnings.warn("Received an empty job array")
            return []
        return self._internal_process_submissions(submissions)

    def submit_array(self, fns: tp.Sequence[tp.Callable[[], R]]) -> tp.List[Job[R]]:
        """Submit a list of job. This is useful when submitting different Checkpointable functions.
        Be mindful that all those functions will be run with the same requirements
        (cpus, gpus, timeout, ...). So try to make group of similar function calls.

        Parameters
        ----------
        fns: list of callable
            functions to compute. Those functions must not need any argument.
            Typically those are "Checkpointable" instance whose arguments
            have been specified in the constructor, or partial functions.

        Returns
        -------
        List[Job]
            A list of Job instances (empty, with a warning, if there is nothing to submit).

        Example
        -------
        a_vals = [1, 2, 3]
        b_vals = [10, 20, 30]
        fns = [functools.partial(int.__add__, a, b) for (a, b) in zip (a_vals, b_vals)]
        executor.submit_array(fns)
        # jobs will compute 1 + 10, 2 + 20, 3 + 30
        """
        submissions = [utils.DelayedSubmission(fn) for fn in fns]
        if not submissions:
            warnings.warn("Received an empty job array")
            return []
        return self._internal_process_submissions(submissions)

    def update_parameters(self, **kwargs: tp.Any) -> None:
        """Update submission parameters.

        Raises
        ------
        RuntimeError
            if called within a batch context
        """
        if self._delayed_batch is not None:
            raise RuntimeError(
                'Changing parameters within batch context "with executor.batch():" is not allowed'
            )
        self._internal_update_parameters(**kwargs)

    @classmethod
    def _equivalence_dict(cls) -> tp.Optional[EquivalenceDict]:
        """Mapping from generic parameter names to specific ones (None here: no renaming)"""
        return None

    @classmethod
    def _valid_parameters(cls) -> tp.Set[str]:
        """Parameters that can be set through update_parameters"""
        return set()

    def _convert_parameters(self, params: tp.Dict[str, tp.Any]) -> tp.Dict[str, tp.Any]:
        """Convert generic parameters to their specific equivalent.
        This has to be called **before** calling `update_parameters`.

        The default implementation only renames the key using `_equivalence_dict`.
        Keys without equivalent are kept as is.
        """
        eq_dict = tp.cast(tp.Optional[tp.Dict[str, str]], self._equivalence_dict())
        if eq_dict is None:
            return params
        return {eq_dict.get(k, k): v for k, v in params.items()}

    def _internal_update_parameters(self, **kwargs: tp.Any) -> None:
        """Update submission parameters."""
        self.parameters.update(kwargs)

    @classmethod
    def affinity(cls) -> int:
        """The 'score' of this executor on the current environment.

        -> -1 means unavailable
        ->  0 means available but won't be started unless asked (eg debug executor)
        ->  1 means available
        ->  2 means available and is a highly scalable executor (cluster)
        """
        return 1


class PicklingExecutor(Executor):
    """Base job executor which pickles each submission into the executor folder
    and then submits a command through a submission file.

    Parameters
    ----------
    folder: Path/str
        folder for storing job submission/output and logs.
    max_num_timeout: int
        maximum number of timeouts after which submitit will not reschedule the job.
        Note: only callable implementing a checkpoint method are rescheduled in case
        of timeout.
    max_pickle_size_gb: float
        maximum size of pickles in GB allowed for a submission.
        Note: during a batch submission, this is the estimated sum of all pickles.
    """

    def __init__(
        self,
        folder: tp.Union[Path, str],
        max_num_timeout: int = 3,
        max_pickle_size_gb: float = 1.0,
    ) -> None:
        super().__init__(folder)
        self.max_num_timeout = max_num_timeout
        self.max_pickle_size_gb = max_pickle_size_gb
        # minimum delay (in seconds) enforced between two consecutive submissions
        self._throttling = 0.2
        # time.time() of the latest submission (0.0 until a job is submitted)
        self._last_job_submitted = 0.0

    def _internal_process_submissions(
        self, delayed_submissions: tp.List[utils.DelayedSubmission]
    ) -> tp.List[Job[tp.Any]]:
        """Pickles each delayed submission and submits one job for each of them.

        Parameters
        ----------
        delayed_submissions: list of DelayedSubmission
            the submissions to dump to disk and submit, processed in order.

        Returns
        -------
        list of Job
            One Job instance per submission, in the same order as the input.

        Raises
        ------
        RuntimeError
            if the estimated total pickle size (size of the first pickle multiplied
            by the number of submissions) is larger than `max_pickle_size_gb`.
        """
        eq_dict = self._equivalence_dict()
        timeout_min = self.parameters.get(eq_dict["timeout_min"] if eq_dict else "timeout_min", 5)
        num = len(delayed_submissions)
        jobs = []
        for index, delayed in enumerate(delayed_submissions):
            tmp_uuid = uuid.uuid4().hex
            pickle_path = utils.JobPaths.get_first_id_independent_folder(self.folder) / f"{tmp_uuid}.pkl"
            pickle_path.parent.mkdir(parents=True, exist_ok=True)
            delayed.set_timeout(timeout_min, self.max_num_timeout)
            delayed.dump(pickle_path)
            if index == 0:
                # only the first pickle is measured: the total size is estimated by
                # assuming all submissions of the batch have a similar size
                size = pickle_path.stat().st_size / 1024**3
                if num * size > self.max_pickle_size_gb:
                    pickle_path.unlink()  # do not leave the oversized pickle behind
                    msg = f"Submitting an estimated {num} x {size:.2f} > {self.max_pickle_size_gb}GB of objects "
                    msg += "(function and arguments) through pickle (this can be slow / overload the file system). "
                    msg += "If this is the intended behavior, you should update executor.max_pickle_size_gb to a larger value "
                    raise RuntimeError(msg)
            self._throttle()
            self._last_job_submitted = _time.time()
            job = self._submit_command(self._submitit_command_str)
            # the pickle was dumped under a temporary name since the job id was unknown
            job.paths.move_temporary_file(pickle_path, "submitted_pickle")
            jobs.append(job)
        return jobs

    def _throttle(self) -> None:
        """Blocks until at least `_throttling` seconds have elapsed since the last submission."""
        while True:
            remaining = self._throttling - (_time.time() - self._last_job_submitted)
            if remaining <= 0:
                return
            # sleep only for the missing time; capping at the throttling delay bounds
            # each sleep even if the clock jumped backward
            _time.sleep(min(remaining, self._throttling))

    @property
    def _submitit_command_str(self) -> str:
        """Command string handed to `_submit_command` for each submission.

        The base implementation only returns a dummy value.
        """
        return "dummy"

    def _submit_command(self, command: str) -> Job[tp.Any]:
        """Submits a command to the cluster
        It is recommended not to use this function since the Job instance assumes pickle
        files will be created at the end of the job, and hence it will not work correctly.
        You may use a CommandFunction as argument to the submit function instead. The only
        problem with this latter solution is that stdout is buffered, and you will therefore
        not be able to monitor the logs in real time.

        Parameters
        ----------
        command: str
            a command string

        Returns
        -------
        Job
            A Job instance, providing access to the job information.
            Since it has no output, some methods will not be efficient
        """
        tmp_uuid = uuid.uuid4().hex
        submission_file_path = (
            utils.JobPaths.get_first_id_independent_folder(self.folder) / f".submission_file_{tmp_uuid}.sh"
        )
        # the folder may not exist yet when this method is called directly
        submission_file_path.parent.mkdir(parents=True, exist_ok=True)
        with submission_file_path.open("w") as f:
            f.write(self._make_submission_file_text(command, tmp_uuid))
        command_list = self._make_submission_command(submission_file_path)
        # run
        output = utils.CommandFunction(command_list, verbose=False)()  # explicit errors
        job_id = self._get_job_id_from_submission_command(output)
        tasks_ids = list(range(self._num_tasks()))
        job: Job[tp.Any] = self.job_class(folder=self.folder, job_id=job_id, tasks=tasks_ids)
        # the submission file was written under a temporary name since the job id was unknown
        job.paths.move_temporary_file(submission_file_path, "submission_file", keep_as_symlink=True)
        self._write_job_id(job.job_id, tmp_uuid)
        self._set_job_permissions(job.paths.folder)
        return job

    def _write_job_id(self, job_id: str, uid: str) -> None:
        """Write the job id in a file named {job-independent folder}/parent_job_id_{uid}.
        This can create files read by plugins to get the job_id of the parent job

        The base implementation does nothing.
        """

    @abc.abstractmethod
    def _num_tasks(self) -> int:
        """Returns the number of tasks associated to the job"""
        raise NotImplementedError

    @abc.abstractmethod
    def _make_submission_file_text(self, command: str, uid: str) -> str:
        """Creates the text of a file which will be created and run
        for the submission (for slurm, this is sbatch file).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _make_submission_command(self, submission_file_path: Path) -> tp.List[str]:
        """Create the submission command."""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def _get_job_id_from_submission_command(string: tp.Union[bytes, str]) -> str:
        """Recover the job id from the output of the submission command."""
        raise NotImplementedError

    @staticmethod
    def _set_job_permissions(folder: Path) -> None:
        """Hook called with the job folder after submission; the base implementation does nothing."""
