# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import sys
from pathlib import Path

import pytest

from ..local import debug
from ..slurm import test_slurm
from . import auto


def test_slurm_executor(tmp_path: Path, monkeypatch) -> None:
    """Exercise ``update_parameters`` on an AutoExecutor created under mocked slurm.

    Covers the accepted case (shared parameters plus a ``debug_``-prefixed one)
    and the rejected cases: a wrongly typed shared parameter, an unknown shared
    parameter, an unknown slurm parameter, and several unknown parameters at once.
    """

    def _fake_valid_parameters():
        # Stub installed on DebugExecutor: the only parameter it reports is "blabla".
        return {"blabla"}

    monkeypatch.setattr(debug.DebugExecutor, "_valid_parameters", _fake_valid_parameters)
    with test_slurm.mocked_slurm():
        slurm_auto = auto.AutoExecutor(folder=tmp_path)
    assert slurm_auto.cluster == "slurm"

    # The debug_blabla parameter must not end up in the slurm executor parameters;
    # only the shared ones do, under the names the assertion below expects.
    slurm_auto.update_parameters(mem_gb=2, name="machin", debug_blabla="blublu")
    assert slurm_auto._executor.parameters == {"mem": "2GB", "job_name": "machin"}

    # A shared parameter given with the wrong type (should be an int) is refused.
    with pytest.raises(AssertionError):
        slurm_auto.update_parameters(mem_gb="2.0GB")

    # A name that is not a known shared parameter is refused.
    with pytest.raises(NameError):
        slurm_auto.update_parameters(blublu=2.0)

    # A slurm_-prefixed name that is not a known slurm parameter is refused.
    with pytest.raises(NameError):
        slurm_auto.update_parameters(slurm_host_filter="blublu")

    # With several offending names, the message must list debug_blublu and then
    # local_num_threads on the following line.
    with pytest.raises(NameError, match=r"debug_blublu.*\n.*local_num_threads"):
        slurm_auto.update_parameters(debug_blublu=2.0, local_num_threads=4)


def test_local_executor(tmp_path: Path) -> None:
    """Requesting ``cluster="local"`` yields a local executor, even inside mocked slurm."""
    with test_slurm.mocked_slurm():
        local_auto = auto.AutoExecutor(cluster="local", folder=tmp_path)
    assert local_auto.cluster == "local"


def test_max_pickle_size_gb_in_auto(tmp_path: Path) -> None:
    """``local_max_pickle_size_gb`` passed to AutoExecutor is set on the wrapped executor."""
    size_limit_gb = 0.12
    auto_ex = auto.AutoExecutor(
        folder=tmp_path, cluster="local", local_max_pickle_size_gb=size_limit_gb
    )
    # The wrapped executor exposes the value as ``max_pickle_size_gb`` (no prefix).
    assert auto_ex._executor.max_pickle_size_gb == size_limit_gb  # type: ignore


def test_python_executor(tmp_path: Path) -> None:
    """A local AutoExecutor given ``local_python`` runs a job and returns its result."""
    local_auto = auto.AutoExecutor(
        folder=tmp_path,
        cluster="local",
        local_python=sys.executable,  # use the interpreter running this test
    )
    submitted = local_auto.submit(lambda: 12)
    outcome = submitted.result()
    assert outcome == 12


def test_executor_argument(tmp_path: Path) -> None:
    """A ``slurm_``-prefixed constructor argument only affects the slurm executor."""
    # Under mocked slurm the value is visible on the wrapped executor.
    with test_slurm.mocked_slurm():
        auto_ex = auto.AutoExecutor(folder=tmp_path, slurm_max_num_timeout=22)
    slurm_value = getattr(auto_ex._executor, "max_num_timeout", None)
    assert slurm_value == 22

    # With a local cluster the same argument must not set the value to 22.
    auto_ex = auto.AutoExecutor(folder=tmp_path, cluster="local", slurm_max_num_timeout=22)
    local_value = getattr(auto_ex._executor, "max_num_timeout", None)
    assert local_value != 22


def test_executor_unknown_argument(tmp_path: Path) -> None:
    """An unknown ``slurm_``-prefixed constructor argument raises ``TypeError``."""
    with test_slurm.mocked_slurm(), pytest.raises(TypeError):
        auto.AutoExecutor(folder=tmp_path, slurm_foobar=22)


def test_executor_deprecated_arguments(tmp_path: Path) -> None:
    """An unprefixed ``max_num_timeout`` warns with a message naming ``slurm_max_num_timeout``."""
    expected_warning = pytest.warns(UserWarning, match="slurm_max_num_timeout")
    with test_slurm.mocked_slurm(), expected_warning:
        auto.AutoExecutor(folder=tmp_path, max_num_timeout=22)


def test_deprecated_argument(tmp_path: Path, monkeypatch) -> None:
    """An unprefixed debug parameter triggers a warning pointing to its prefixed name."""

    def _fake_valid_parameters():
        # Stub installed on DebugExecutor: the only parameter it reports is "blabla".
        return {"blabla"}

    monkeypatch.setattr(debug.DebugExecutor, "_valid_parameters", _fake_valid_parameters)
    with test_slurm.mocked_slurm():
        slurm_auto = auto.AutoExecutor(folder=tmp_path)
    assert slurm_auto.cluster == "slurm"

    # Passing 'blabla' without prefix must warn, and the message must mention
    # 'blabla' followed by 'debug_blabla'.
    with pytest.warns(UserWarning, match=r"blabla.*debug_blabla"):
        slurm_auto.update_parameters(mem_gb=2, blabla="blublu")


def test_overriden_arguments(tmp_path: Path) -> None:
    """Cluster-specific parameters take precedence over their shared counterparts."""
    with test_slurm.mocked_slurm():
        slurm_auto = auto.AutoExecutor(folder=tmp_path, cluster="slurm")

    # Each shared value is given alongside a slurm_ value for the same setting.
    slurm_auto.update_parameters(
        timeout_min=60,
        slurm_timeout_min=120,
        tasks_per_node=2,
        slurm_ntasks_per_node=3,
    )
    # Only the slurm_ values (120 and 3) are expected in the resulting parameters.
    assert slurm_auto._executor.parameters == {"time": 120, "ntasks_per_node": 3}

    # On a local executor, the same kind of update mixing timeout_min with a
    # slurm_ parameter is only required to complete without raising.
    local_auto = auto.AutoExecutor(folder=tmp_path, cluster="local")
    local_auto.update_parameters(timeout_min=60, slurm_time=120)


def test_auto_batch_watcher(tmp_path: Path) -> None:
    """A job submitted within ``batch()`` is not reported done once the batch closes."""
    with test_slurm.mocked_slurm():
        batch_ex = auto.AutoExecutor(folder=tmp_path)
        with batch_ex.batch():
            pending = batch_ex.submit(print, "hi")
        # Checked while the slurm mock is still active.
        assert not pending.done()


def test_redirect_stdout_stderr(executor) -> None:
    """Check job logs with ``stderr_to_stdout`` switched on, then off.

    ``executor`` is a fixture provided outside this module.
    """

    # Payload: writes one word to stdout and another one to stderr.
    def log_to_stderr_and_stdout():
        print("hello")
        print("world", file=sys.stderr)

    # Redirection on: stderr() returns None and both words are found in stdout.
    executor.update_parameters(stderr_to_stdout=True)
    merged_job = executor.submit(log_to_stderr_and_stdout)
    merged_job.wait()
    assert merged_job.stderr() is None
    merged_output = merged_job.stdout()
    for word in ("hello", "world"):
        assert word in merged_output

    # Redirection off: each word is found in its own stream.
    executor.update_parameters(stderr_to_stdout=False)
    split_job = executor.submit(log_to_stderr_and_stdout)
    split_job.wait()
    assert "world" in split_job.stderr()
    assert "hello" in split_job.stdout()
