'''
Combine EC profiles from three independent beams
'''

from .. import TestCase, options, load_table, run_gray

from pathlib import Path

import unittest
import shutil
import tempfile


class Test(TestCase):
    inputs: Path     # directory of the input files
    reference: Path  # directory of the reference outputs
    candidate: Path  # directory of the candidate outputs

    _tempdir: str         # temporary directory created by setUpClass
    _passed: bool = True  # False once any test of this class fails/errors

    @classmethod
    def setUpClass(cls):
        '''
        Sets up the test case

        Runs gray on the input files and stores its outputs, together with
        its stderr (in the file `log`), in the candidate directory.
        '''
        # directory of the test case (derived from the module path,
        # relative to the current working directory)
        base = Path().joinpath(*cls.__module__.split('.'))
        cls.inputs = base / 'inputs'
        cls.reference = base / 'outputs'

        # no test of this class has failed yet
        cls._passed = True

        # temporary directory holding the candidate outputs
        cls._tempdir = tempfile.mkdtemp(prefix=f'gray-test-{base.name}.')
        cls.candidate = Path(cls._tempdir)

        # replace reference with candidate
        if options.update:
            print()
            print('Setting new reference for ' + cls.__module__)
            cls.candidate = cls.reference

        # run gray to generate the candidate outputs
        proc = run_gray(cls.inputs, cls.candidate, binary=options.binary,
                        options=['-s', cls.inputs / 'filelist.txt'])

        # store the stderr for manual inspection; this is done before the
        # exit code is checked so the log also exists when gray fails
        with open(str(cls.candidate / 'log'), 'w') as log:
            log.write(proc.stderr)

        # An explicit raise is used because `assert` is stripped under -O.
        # unittest does not call tearDownClass when setUpClass raises, so
        # remove (or preserve, with keep_failed) the temporary directory here.
        if proc.returncode != 0:
            cls._passed = False
            cls.tearDownClass()
            raise AssertionError(
                f"gray failed with exit code {proc.returncode}")

    @classmethod
    def tearDownClass(cls):
        '''
        Clean up after all tests

        The temporary directory is removed, unless some test of this class
        failed and `options.keep_failed` is set: then it is kept and its
        path is printed.
        '''
        # remove temporary directory
        if cls._passed or not options.keep_failed:
            shutil.rmtree(cls._tempdir)
        else:
            print()
            print('Some tests failed: preserving outputs in', cls._tempdir)

    def run(self, result: unittest.TestResult):
        '''
        Override to store the test results for tearDownClass

        Only the failures and errors (including failed subtests) added while
        running this test are counted. Both kinds mark the class as failed,
        and problems from other test classes sharing `result` are ignored.
        '''
        before = len(result.failures) + len(result.errors)
        outcome = TestCase.run(self, result)
        if len(result.failures) + len(result.errors) > before:
            self.__class__._passed = False
        return outcome

    def test_eccd_values(self):
        '''
        Comparing the ECCD values

        Every column of the summary table is compared between reference
        and candidate, each one in its own subtest.
        '''
        from collections import defaultdict

        ref = load_table(self.reference / 'sum-summary.txt')
        cand = load_table(self.candidate / 'sum-summary.txt')

        # precision as number of decimal places (passed positionally to
        # assertAlmostEqual): 3 by default, -2 for the dPdV and J_φ
        # peak/max values
        prec = defaultdict(lambda: 3, [
            ('dPdV_peak', -2), ('dPdV_max', -2),
            ('J_φ_peak', -2), ('J_φ_max', -2),
        ])

        for val in ref.dtype.names:
            with self.subTest(value=val):
                self.assertAlmostEqual(ref[val], cand[val], prec[val],
                                       msg=f"{val} changed")

    def test_ec_profiles(self):
        '''
        Comparing the EC radial profiles

        Only the column layout of the two tables is checked for now. The
        numerical comparison has not been written yet, so the test reports
        itself as skipped instead of silently passing.
        '''
        ref = load_table(self.reference / 'sum-ec-profiles.txt')
        cand = load_table(self.candidate / 'sum-ec-profiles.txt')

        # both files must contain the same quantities, in the same order
        self.assertEqual(ref.dtype.names, cand.dtype.names,
                         msg="EC profile columns changed")

        # TODO: compare the profile values column by column
        self.skipTest('comparison of the EC profile values not implemented')