# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import logging
import re
import typing as tp
from pathlib import Path

import pkg_resources
import pytest

from . import core, plugins
from .job_environment import JobEnvironment


@pytest.mark.parametrize("env", plugins.get_job_environments().values())
def test_env(env: JobEnvironment) -> None:
    assert isinstance(env, JobEnvironment)
    # We are not inside a submitit job
    assert not env.activated()
    assert type(env)._requeue is not JobEnvironment._requeue, "_requeue need to be overridden"


@pytest.mark.parametrize("ex", plugins.get_executors().values())
def test_executors(ex: tp.Type[core.Executor]) -> None:
    assert isinstance(ex, type)
    assert issubclass(ex, core.Executor)
    assert ex.affinity() >= -1


def test_finds_default_environments() -> None:
    envs = plugins.get_job_environments()
    assert len(envs) >= 3
    assert "slurm" in envs
    assert "local" in envs
    assert "debug" in envs


def test_finds_default_executors() -> None:
    ex = plugins.get_executors()
    assert len(ex) >= 3
    assert "slurm" in ex
    assert "local" in ex
    assert "debug" in ex


def test_job_environment_works(monkeypatch):
    monkeypatch.setenv("_TEST_CLUSTER_", "slurm")
    env = plugins.get_job_environment()
    assert env.cluster == "slurm"
    assert type(env).__name__ == "SlurmJobEnvironment"

    env2 = JobEnvironment()
    assert env2.cluster == "slurm"
    assert type(env2).__name__ == "SlurmJobEnvironment"


def test_job_environment_raises_outside_of_job() -> None:
    with pytest.raises(RuntimeError, match=r"which environment.*slurm.*local.*debug"):
        plugins.get_job_environment()


class PluginCreator:
    def __init__(self, tmp_path: Path, monkeypatch):
        self.tmp_path = tmp_path
        self.monkeypatch = monkeypatch

    def add_plugin(self, name: str, entry_points: str, init: str):
        plugin = self.tmp_path / name
        plugin.mkdir(mode=0o777)
        plugin_egg = plugin.with_suffix(".egg-info")
        plugin_egg.mkdir(mode=0o777)

        (plugin_egg / "entry_points.txt").write_text(entry_points)
        (plugin / "__init__.py").write_text(init)

        # also fix pkg_resources since it already has loaded old packages in other tests.
        working_set = pkg_resources.WorkingSet([str(self.tmp_path)])
        self.monkeypatch.setattr(pkg_resources, "iter_entry_points", working_set.iter_entry_points)

    def __enter__(self) -> None:
        _clear_plugin_cache()
        self.monkeypatch.syspath_prepend(self.tmp_path)

    def __exit__(self, *exception: tp.Any) -> None:
        _clear_plugin_cache()


def _clear_plugin_cache() -> None:
    plugins._get_plugins.cache_clear()
    plugins.get_executors.cache_clear()


@pytest.fixture(name="plugin_creator")
def _plugin_creator(tmp_path: Path, monkeypatch) -> tp.Iterator[PluginCreator]:
    creator = PluginCreator(tmp_path, monkeypatch)
    with creator:
        yield creator


def test_find_good_plugin(plugin_creator: PluginCreator) -> None:
    plugin_creator.add_plugin(
        "submitit_good",
        entry_points="""[submitit]
executor = submitit_good:GoodExecutor
job_environment = submitit_good:GoodJobEnvironment
unsupported_key = submitit_good:SomethingElse
""",
        init="""
import submitit

class GoodExecutor(submitit.Executor):
    pass

class GoodJobEnvironment:
    pass
""",
    )

    executors = plugins.get_executors().keys()
    # Only the plugins declared with plugin_creator are visible.
    assert set(executors) == {"good", "slurm", "local", "debug"}


def test_skip_bad_plugin(caplog, plugin_creator: PluginCreator) -> None:
    caplog.set_level(logging.WARNING, logger="submitit")
    plugin_creator.add_plugin(
        "submitit_bad",
        entry_points="""[submitit]
executor = submitit_bad:NonExisitingExecutor
job_environment = submitit_bad:BadEnvironment
unsupported_key = submitit_bad:SomethingElse
""",
        init="""
import submitit

class BadEnvironment:
    name = "bad"

    def __init__(self):
        raise Exception("this is a bad environment")
""",
    )

    executors = plugins.get_executors().keys()
    assert {"slurm", "local", "debug"} == set(executors)
    assert "bad" not in executors
    expected = [
        (logging.ERROR, r"'submitit_bad'.*no attribute 'NonExisitingExecutor'"),
        (logging.ERROR, r"'submitit_bad'.*this is a bad environment"),
        (logging.WARNING, "unsupported_key = submitit_bad:SomethingElse"),
    ]
    assert len(caplog.records) == len(expected)
    for record, ex_record in zip(caplog.records, expected):
        assert record.name == "submitit"
        assert record.levelno == ex_record[0]
        assert re.search(ex_record[1], record.getMessage())
