Instead of parsing the arguments twice, wrap unittest.main and inject our custom argument into the existing parser.
371 lines
13 KiB
Python
371 lines
13 KiB
Python
from pathlib import Path
|
||
from typing import Any
|
||
from unittest import TestCase
|
||
from collections import defaultdict
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
|
||
import shutil
|
||
|
||
import unittest
|
||
import tempfile
|
||
import subprocess
|
||
import itertools
|
||
import argparse
|
||
|
||
|
||
class GrayTest:
|
||
inputs: Path = None # directory of the input files
|
||
reference: Path = None # directory of the reference outputs
|
||
candidate: Path = None # directory of the candidate outputs
|
||
|
||
# command line options
|
||
options: argparse.Namespace = argparse.Namespace()
|
||
|
||
# output tables to save
|
||
tables: list[int] = [4, 7, 8, 9, 48, 56, 33, 70, 71]
|
||
|
||
# Extra parameters to pass to gray
|
||
gray_params: dict[str, Any] = {}
|
||
|
||
@classmethod
|
||
def setUpClass(cls):
|
||
'''
|
||
Sets up the test case
|
||
'''
|
||
# directory of the test case
|
||
base = Path().joinpath(*cls.__module__.split('.'))
|
||
|
||
if cls.inputs is None:
|
||
cls.inputs = base / 'inputs'
|
||
if cls.reference is None:
|
||
cls.reference = base / 'outputs'
|
||
|
||
# temporary directory holding the candidate outputs
|
||
cls._tempdir = tempfile.mkdtemp(prefix=f'gray-test-{base.name}.')
|
||
cls.candidate = Path(cls._tempdir)
|
||
|
||
# replace reference with candidate
|
||
if cls.options.update:
|
||
print()
|
||
print('Setting new reference for ' + cls.__module__)
|
||
cls.candidate = cls.reference
|
||
|
||
# run gray to generate the candidate outputs
|
||
proc = run_gray(cls.inputs, cls.candidate, params=cls.gray_params,
|
||
binary=cls.options.binary, tables=cls.tables)
|
||
|
||
# 0: all good, 1: input errors, >1: simulation errors
|
||
assert proc.returncode != 1, 'gray failed with exit code 1'
|
||
|
||
# store the stderr for manual inspection
|
||
with open(str(cls.candidate / 'log'), 'w') as log:
|
||
log.write(proc.stderr)
|
||
|
||
@classmethod
|
||
def tearDownClass(cls):
|
||
'''
|
||
Clean up after all tests
|
||
'''
|
||
# remove temporary directory
|
||
if cls._passed or not cls.options.keep_failed:
|
||
shutil.rmtree(cls._tempdir)
|
||
else:
|
||
print()
|
||
print('Some tests failed: preserving outputs in', cls._tempdir)
|
||
|
||
def run(self, result: unittest.runner.TextTestResult):
|
||
'''
|
||
Override to store the test results for tearDownClass
|
||
'''
|
||
TestCase.run(self, result)
|
||
self.__class__._passed = result.failures == []
|
||
|
||
def test_eccd_values(self):
|
||
'''
|
||
Comparing the ECCD values
|
||
'''
|
||
try:
|
||
ref = load_table(self.reference / 'summary.7.txt')
|
||
cand = load_table(self.candidate / 'summary.7.txt')
|
||
except FileNotFoundError:
|
||
raise unittest.SkipTest("ECCD results not available")
|
||
|
||
# precision as number of decimal places
|
||
prec = defaultdict(lambda: 3, [
|
||
('dPdV_peak', -2), ('dPdV_max', -2),
|
||
('J_φ_peak', -2), ('J_φ_max', -2),
|
||
('s_max', -1), ('χ', -1), ('ψ', -1),
|
||
])
|
||
|
||
for val in ref.dtype.names:
|
||
with self.subTest(value=val):
|
||
for i, ray in enumerate(ref['index_rt']):
|
||
ref_val = ref[val][i]
|
||
cand_val = cand[val][i]
|
||
msg = f"{val} changed (ray {int(ray)})"
|
||
self.assertAlmostEqual(ref_val, cand_val, prec[val],
|
||
msg=msg)
|
||
|
||
def test_eccd_profiles(self):
|
||
'''
|
||
Comparing the ECCD radial profiles
|
||
'''
|
||
from scipy.stats import wasserstein_distance as emd
|
||
import numpy as np
|
||
|
||
try:
|
||
ref = load_table(self.reference / 'ec-profiles.48.txt')
|
||
cand = load_table(self.candidate / 'ec-profiles.48.txt')
|
||
except FileNotFoundError:
|
||
raise unittest.SkipTest("ECCD profiles not available")
|
||
|
||
beams = np.unique(ref['index_rt'])
|
||
for index_rt, val in itertools.product(beams, ['J_cd', 'dPdV', 'J_φ']):
|
||
ref_beam = ref[ref['index_rt'] == index_rt]
|
||
cand_beam = cand[cand['index_rt'] == index_rt]
|
||
|
||
# skip if both empty
|
||
if np.all(ref_beam[val] == 0) and np.all(cand_beam[val] == 0):
|
||
continue
|
||
|
||
# compare with the earth mover's distance
|
||
with self.subTest(profile=val, beam=index_rt):
|
||
y1 = abs(ref_beam[val]) / np.sum(abs(ref_beam[val]))
|
||
y2 = abs(cand_beam[val]) / np.sum(abs(cand_beam[val]))
|
||
dist = emd(ref_beam['ρ_t'], cand_beam['ρ_t'], y1, y2)
|
||
self.assertLess(dist, 0.001, f'{val} profile changed')
|
||
|
||
if self.options.visual:
|
||
for index_rt in beams:
|
||
ref_beam = ref[ref['index_rt'] == index_rt]
|
||
cand_beam = cand[cand['index_rt'] == index_rt]
|
||
|
||
fig, axes = plt.subplots(3, 1, sharex=True)
|
||
fig.suptitle(self.__module__ + '.test_ec_profiles')
|
||
|
||
axes[0].set_title(f'beam {int(index_rt)}', loc='right')
|
||
axes[0].set_ylabel('$J_\\text{cd}$')
|
||
axes[0].plot(ref_beam['ρ_t'], ref_beam['J_cd'],
|
||
c='xkcd:red', label='reference')
|
||
axes[0].plot(cand_beam['ρ_t'], cand_beam['J_cd'],
|
||
c='xkcd:green', ls='-.', label='candidate')
|
||
axes[0].legend()
|
||
|
||
axes[1].set_ylabel('$dP/dV$')
|
||
axes[1].plot(ref_beam['ρ_t'], ref_beam['dPdV'],
|
||
c='xkcd:red')
|
||
axes[1].plot(cand_beam['ρ_t'], cand_beam['dPdV'],
|
||
c='xkcd:green', ls='-.')
|
||
|
||
axes[2].set_xlabel('$ρ_t$')
|
||
axes[2].set_ylabel('$J_φ$')
|
||
axes[2].plot(ref_beam['ρ_t'], ref_beam['J_φ'],
|
||
c='xkcd:red')
|
||
axes[2].plot(cand_beam['ρ_t'], cand_beam['J_φ'],
|
||
c='xkcd:green', ls='-.')
|
||
plt.show()
|
||
|
||
def test_flux_averages(self):
|
||
'''
|
||
Comparing the flux averages table
|
||
'''
|
||
try:
|
||
ref = load_table(self.reference / 'flux-averages.56.txt')
|
||
cand = load_table(self.candidate / 'flux-averages.56.txt')
|
||
except FileNotFoundError:
|
||
raise unittest.SkipTest("Flux averages table not available")
|
||
|
||
# precision as number of decimal places
|
||
prec = defaultdict(lambda: 3, [
|
||
('J_φ_avg', -3), ('I_pl', -3),
|
||
('area', 1), ('vol', 0),
|
||
('B_avg', 1), ('B_max', 1), ('B_min', 1),
|
||
])
|
||
|
||
for col in ref.dtype.names:
|
||
with self.subTest(value=col):
|
||
for row in range(ref.size):
|
||
ref_val = ref[col][row]
|
||
cand_val = cand[col][row]
|
||
line = row + 23
|
||
self.assertAlmostEqual(ref_val, cand_val, prec[col],
|
||
msg=f"{col} at line {line} changed")
|
||
|
||
if self.options.visual:
|
||
fig, axes = plt.subplots(4, 3, tight_layout=True)
|
||
fig.suptitle(self.__module__ + '.test_flux_averages')
|
||
|
||
for ax, col in zip(axes.flatten(), ref.dtype.names[2:]):
|
||
ax.set_xlabel('$ρ_p$')
|
||
ax.set_ylabel(col)
|
||
ax.plot(ref['ρ_p'], ref[col], c='xkcd:red')
|
||
ax.plot(cand['ρ_p'], cand[col], c='xkcd:green', ls='-.')
|
||
|
||
axes[3, 2].axis('off')
|
||
axes[3, 2].plot(np.nan, np.nan, c='xkcd:red', label='reference')
|
||
axes[3, 2].plot(np.nan, np.nan, c='xkcd:green', label='candidate')
|
||
axes[3, 2].legend()
|
||
plt.show()
|
||
|
||
def test_final_position(self):
|
||
'''
|
||
Comparing the final position of the central ray
|
||
'''
|
||
ref = load_table(self.reference / 'central-ray.4.txt')
|
||
cand = load_table(self.candidate / 'central-ray.4.txt')
|
||
|
||
# coordinates
|
||
self.assertAlmostEqual(ref['R'][-1], cand['R'][-1], 1)
|
||
self.assertAlmostEqual(ref['z'][-1], cand['z'][-1], 1)
|
||
self.assertAlmostEqual(ref['φ'][-1], cand['φ'][-1], 2)
|
||
|
||
# optical path length
|
||
self.assertAlmostEqual(ref['s'][-1], cand['s'][-1], 1)
|
||
|
||
def test_final_direction(self):
|
||
'''
|
||
Comparing the final direction of the central ray
|
||
'''
|
||
ref = load_table(self.reference / 'central-ray.4.txt')
|
||
cand = load_table(self.candidate / 'central-ray.4.txt')
|
||
|
||
self.assertAlmostEqual(ref['N_⊥'][-1], cand['N_⊥'][-1], 1)
|
||
self.assertAlmostEqual(ref['N_∥'][-1], cand['N_∥'][-1], 1)
|
||
|
||
def test_beam_shape(self):
|
||
'''
|
||
Comparing the final beam shape
|
||
'''
|
||
try:
|
||
ref = load_table(self.reference / 'beam-shape-final.9.txt')
|
||
cand = load_table(self.candidate / 'beam-shape-final.9.txt')
|
||
except FileNotFoundError:
|
||
raise unittest.SkipTest("Beam shape info not available")
|
||
|
||
if self.options.visual:
|
||
plt.subplot(aspect='equal')
|
||
plt.title(self.__module__ + '.test_beam_shape')
|
||
plt.xlabel('$x$ / cm')
|
||
plt.ylabel('$y$ / cm')
|
||
plt.scatter(ref['x'], ref['y'], c='red',
|
||
marker='_', label='reference')
|
||
plt.scatter(cand['x'], cand['y'], c='green',
|
||
alpha=0.6, marker='+', label='candidate')
|
||
plt.legend()
|
||
plt.show()
|
||
|
||
for ref, cand in zip(ref, cand):
|
||
with self.subTest(ray=(int(ref['j']), int(ref['k']))):
|
||
self.assertAlmostEqual(ref['x'], cand['x'], 1)
|
||
self.assertAlmostEqual(ref['y'], cand['y'], 1)
|
||
|
||
def test_error_biased(self):
|
||
'''
|
||
Test for a proportionality between Λ and any of X, Y, N∥
|
||
'''
|
||
|
||
data = load_table(self.candidate / 'central-ray.4.txt')
|
||
|
||
# restrict to within the plasma, half of the first pass
|
||
in_plasma = data['X'] > 0
|
||
first_pass = data['index_rt'] == data['index_rt'].min()
|
||
data = data[in_plasma & first_pass]
|
||
data = data[:int(data.size // 2)]
|
||
|
||
if data.size < 2:
|
||
self.skipTest("There is no plasma")
|
||
|
||
if self.options.visual:
|
||
left = plt.subplot()
|
||
plt.title(self.__module__ + '.test_error_biased')
|
||
left.set_xlabel('$s$ / cm')
|
||
left.set_ylabel('$Λ$', color='xkcd:ocean blue')
|
||
left.tick_params(axis='y', labelcolor='xkcd:ocean blue')
|
||
left.plot(data['s'], data['Λ_r'], color='xkcd:ocean blue')
|
||
|
||
right1 = left.twinx()
|
||
right1.set_ylabel('$X$', color='xkcd:orange')
|
||
right1.tick_params(axis='y', labelcolor='xkcd:orange')
|
||
right1.plot(data['s'], data['X'], color='xkcd:orange')
|
||
|
||
right2 = left.twinx()
|
||
right2.set_ylabel('$Y$', color='xkcd:vermillion')
|
||
right2.tick_params(axis='y', labelcolor='xkcd:vermillion')
|
||
right2.plot(data['s'], data['Y'], color='xkcd:vermillion')
|
||
right2.spines["right"].set_position(("axes", 1.1))
|
||
|
||
right3 = left.twinx()
|
||
right3.set_ylabel('$N_∥$', color='xkcd:green')
|
||
right3.tick_params(axis='y', labelcolor='xkcd:green')
|
||
right3.spines["right"].set_position(("axes", 1.2))
|
||
right3.plot(data['s'], data['N_∥'], color='xkcd:green')
|
||
|
||
plt.subplots_adjust(right=0.78)
|
||
plt.show()
|
||
|
||
err = data['Λ_r'].var() / 10
|
||
self.assertGreater(err, 0, msg="Λ is exactly constant")
|
||
|
||
for var in ['X', 'Y', 'N_⊥']:
|
||
# Minimise the χ²(k) = |(Λ_r - k⋅var) / err|² / (n - 1)
|
||
# The solution is simply: k = (Λ⋅var)/var⋅var
|
||
k = np.dot(data['Λ_r'], data[var]) / np.linalg.norm(data[var])**2
|
||
with self.subTest(var=var):
|
||
res = (data['Λ_r'] - k*data[var]) / err
|
||
χ2 = np.linalg.norm(res)**2 / (data.size - 1)
|
||
self.assertGreater(χ2, 1)
|
||
|
||
|
||
def get_basedir(module: str) -> Path:
    """
    Map a dotted module name to the matching relative directory:
    each dot-separated component becomes one path component
    (e.g. tests.03-TCV gives tests/03-TCV).
    """
    components = module.split('.')
    return Path(*components)
|
||
|
||
|
||
def run_gray(inputs: Path, outputs: Path,
             # extra gray parameters
             params: dict[str, Any] = None,
             # which tables to generate
             tables: list[int] = None,
             # which gray binary to use
             binary: str = 'gray',
             # extra options
             options: list[str] = None
             ) -> subprocess.CompletedProcess:
    '''
    Runs gray on the inputs from the `inputs` directory and storing the results
    in the `outputs` directory.

    The `outputs` directory is created if missing. The command line is
    built as follows:

      - `-c` is given the `gray.ini` file inside `inputs`;
      - `-t` is given the `tables`, joined with commas;
      - `-o` is given the `outputs` directory;
      - `-v` is always passed;
      - each item of `params` adds a `-g key=value` pair;
      - `options` (any sequence of strings) is appended verbatim.

    Omitting `params`, `tables` or `options` is the same as passing
    an empty collection.

    Returns the completed process, with stdout and stderr captured as
    text. A non-zero exit status does not raise: the command line and
    the captured output are printed instead.
    '''
    outputs.mkdir(exist_ok=True, parents=True)

    # None stands for "nothing given"; the defaults are not mutable
    # objects shared between calls
    params = {} if params is None else params
    tables = () if tables is None else tables
    options = () if options is None else options

    overrides = itertools.chain.from_iterable(
        ('-g', f'{k}={v}') for k, v in params.items())
    args = [
        binary,
        '-c', str(inputs / 'gray.ini'),
        '-t', ','.join(map(str, tables)),
        '-o', str(outputs),
        '-v',
        *overrides,
        *options,
    ]
    proc = subprocess.run(args, capture_output=True, text=True)

    print()
    if proc.returncode != 0:
        # show the log on errors
        print(f'Errors occurred (exit status {proc.returncode}), showing log:')
        print(*proc.args)
        print(proc.stderr)
        print(proc.stdout)
    return proc
|
||
|
||
|
||
def load_table(fname: Path) -> np.ndarray:
    '''
    Loads a GRAY output file as a structured numpy array
    (columns are named as in the file header)

    The first 21 lines of the file are skipped and the following one
    is read as the list of column names. The result always has at
    least one dimension, even when the table holds a single row.
    '''
    return np.genfromtxt(fname, names=True, skip_header=21, ndmin=1)
|