gray/tests/__init__.py

275 lines
9.0 KiB
Python
Raw Normal View History

2023-12-18 00:52:11 +01:00
from pathlib import Path
from typing import Any
from unittest import TestCase
import warnings
import numpy as np
import shutil
import unittest
import tempfile
import subprocess
import itertools
import argparse
class GrayTest:
    '''
    Checks shared by the gray test cases.

    The class does not derive from TestCase itself but relies on its
    methods (assertAlmostEqual, subTest, skipTest, run), so it has to be
    combined with unittest.TestCase in the concrete test classes.
    '''
    inputs: Path = None     # directory of the input files
    reference: Path = None  # directory of the reference outputs
    candidate: Path = None  # directory of the candidate outputs

    # Extra parameters to pass to gray
    gray_params: dict[str, Any] = {}

    # True as long as no test of the class has failed or errored;
    # updated by `run` and read by `tearDownClass`
    _passed: bool = True

    @classmethod
    def setUpClass(cls):
        '''
        Sets up the test case

        Resolves the default input/reference directories from the module
        name, creates a temporary directory for the candidate outputs and
        runs gray once for the whole class.

        Raises AssertionError if gray exits with a nonzero code.
        '''
        # start from a clean slate, whatever happened in previous runs
        cls._passed = True

        # directory of the test case
        base = Path().joinpath(*cls.__module__.split('.'))
        if cls.inputs is None:
            cls.inputs = base / 'inputs'
        if cls.reference is None:
            cls.reference = base / 'outputs'

        # temporary directory holding the candidate outputs
        cls._tempdir = tempfile.mkdtemp(prefix=f'gray-test-{base.name}.')
        cls.candidate = Path(cls._tempdir)

        # replace reference with candidate
        if options.update:
            print()
            print('Setting new reference for ' + cls.__module__)
            cls.candidate = cls.reference

        # run gray to generate the candidate outputs
        proc = run_gray(cls.inputs, cls.candidate, params=cls.gray_params,
                        binary=options.binary)
        # explicit raise (not `assert`) so the check survives `python -O`
        if proc.returncode != 0:
            raise AssertionError(
                f"gray failed with exit code {proc.returncode}")

        # store the stderr for manual inspection
        with open(str(cls.candidate / 'log'), 'w') as log:
            log.write(proc.stderr)

    @classmethod
    def tearDownClass(cls):
        '''
        Clean up after all tests

        The temporary directory is kept only when some test did not pass
        and the `keep_failed` option is set.
        '''
        # remove temporary directory
        if cls._passed or not options.keep_failed:
            shutil.rmtree(cls._tempdir)
        else:
            print()
            print('Some tests failed: preserving outputs in', cls._tempdir)

    def run(self, result: unittest.TestResult = None):
        '''
        Override to store the test results for tearDownClass

        Runs the test through TestCase.run and records on the class
        whether it added any failure or error to `result`.
        Returns the result object, like TestCase.run does.
        '''
        def problems(res) -> int:
            # failures and errors recorded so far (none without a result)
            if res is None:
                return 0
            return len(res.failures) + len(res.errors)

        # `result` is shared by the whole run: compare against the count
        # before this test, so problems of other tests are not attributed
        # to this one
        before = problems(result)
        outcome = TestCase.run(self, result)
        if outcome is None:
            outcome = result
        clean = problems(outcome) == before

        # a single bad test is enough to mark the whole class as failed
        cls = type(self)
        cls._passed = cls._passed and clean
        return outcome

    def test_eccd_values(self):
        '''
        Comparing the ECCD values
        '''
        from collections import defaultdict

        try:
            ref = load_table(self.reference / 'summary.7.txt')
            cand = load_table(self.candidate / 'summary.7.txt')
        except FileNotFoundError:
            raise unittest.SkipTest("ECCD results not available")

        # precision as number of decimal places
        # (negative values round to the left of the decimal point)
        prec = defaultdict(lambda: 3, [
            ('dPdV_peak', -2), ('dPdV_max', -2),
            ('J_φ_peak', -2), ('J_φ_max', -2),
            ('s_max', -1), ('χ', -1), ('ψ', -1),
        ])

        for val in ref.dtype.names:
            with self.subTest(value=val):
                for i, ray in enumerate(ref['index_rt']):
                    self.assertAlmostEqual(
                        ref[val][i], cand[val][i], prec[val],
                        msg=f"{val} changed (ray {int(ray)})")

    def test_final_position(self):
        '''
        Comparing the final position of the central ray
        '''
        ref = load_table(self.reference / 'central-ray.4.txt')
        cand = load_table(self.candidate / 'central-ray.4.txt')

        # coordinates
        self.assertAlmostEqual(ref['R'][-1], cand['R'][-1], 1)
        self.assertAlmostEqual(ref['z'][-1], cand['z'][-1], 1)
        self.assertAlmostEqual(ref['φ'][-1], cand['φ'][-1], 2)

        # optical path length
        self.assertAlmostEqual(ref['s'][-1], cand['s'][-1], 1)

    def test_final_direction(self):
        '''
        Comparing the final direction of the central ray
        '''
        ref = load_table(self.reference / 'central-ray.4.txt')
        cand = load_table(self.candidate / 'central-ray.4.txt')

        self.assertAlmostEqual(ref['N_⊥'][-1], cand['N_⊥'][-1], 1)
        self.assertAlmostEqual(ref['N_∥'][-1], cand['N_∥'][-1], 1)

    def test_beam_shape(self):
        '''
        Comparing the final beam shape
        '''
        try:
            ref = load_table(self.reference / 'beam-shape-final.9.txt')
            cand = load_table(self.candidate / 'beam-shape-final.9.txt')
        except FileNotFoundError:
            raise unittest.SkipTest("Beam shape info not available")

        if options.visual:
            import matplotlib.pyplot as plt

            plt.subplot(aspect='equal')
            plt.title(self.__module__ + '.test_beam_shape')
            plt.xlabel('$x$ / cm')
            plt.ylabel('$y$ / cm')
            plt.scatter(ref['x'], ref['y'], c='red',
                        marker='_', label='reference')
            plt.scatter(cand['x'], cand['y'], c='green',
                        alpha=0.6, marker='+', label='candidate')
            plt.legend()
            plt.show()

        # row-by-row comparison; distinct names avoid shadowing the tables
        for ref_ray, cand_ray in zip(ref, cand):
            with self.subTest(ray=(int(ref_ray['j']), int(ref_ray['k']))):
                self.assertAlmostEqual(ref_ray['x'], cand_ray['x'], 1)
                self.assertAlmostEqual(ref_ray['y'], cand_ray['y'], 1)

    def test_error_biased(self):
        '''
        Test for a proportionality between Λ and any of X, Y, N∥

        For each variable the best fit Λ(s) = k⋅var(s) is computed and
        the test fails if its reduced χ² is not greater than 1.
        '''
        data = load_table(self.candidate / 'central-ray.4.txt')

        # restrict to within the plasma, half of the first pass
        in_plasma = data['X'] > 0
        first_pass = data['index_rt'] == data['index_rt'].min()
        data = data[in_plasma & first_pass]
        data = data[:int(data.size // 2)]

        if data.size < 2:
            self.skipTest("There is no plasma")

        if options.visual:
            import matplotlib.pyplot as plt

            left = plt.subplot()
            plt.title(self.__module__ + '.test_error_biased')
            left.set_xlabel('$s$ / cm')
            left.set_ylabel('$Λ$', color='xkcd:ocean blue')
            left.tick_params(axis='y', labelcolor='xkcd:ocean blue')
            left.plot(data['s'], data['Λ_r'], color='xkcd:ocean blue')

            right1 = left.twinx()
            right1.set_ylabel('$X$', color='xkcd:orange')
            right1.tick_params(axis='y', labelcolor='xkcd:orange')
            right1.plot(data['s'], data['X'], color='xkcd:orange')

            right2 = left.twinx()
            right2.set_ylabel('$Y$', color='xkcd:vermillion')
            right2.tick_params(axis='y', labelcolor='xkcd:vermillion')
            right2.plot(data['s'], data['Y'], color='xkcd:vermillion')
            right2.spines["right"].set_position(("axes", 1.1))

            right3 = left.twinx()
            right3.set_ylabel('$N_∥$', color='xkcd:green')
            right3.tick_params(axis='y', labelcolor='xkcd:green')
            right3.spines["right"].set_position(("axes", 1.2))
            right3.plot(data['s'], data['N_∥'], color='xkcd:green')

            plt.subplots_adjust(right=0.78)
            plt.show()

        # scale used to normalise the fit residuals
        err = data['Λ_r'].var() / 10
        self.assertGreater(err, 0, msg="Λ is exactly constant")

        def χ2(k, var):
            '''
            Reduced χ² for the curve fit: Λ(s) = k⋅var(s)
            '''
            res = (data['Λ_r'] - k*data[var]) / err
            return np.sum(res**2) / (data.size - 1)

        import scipy.optimize

        # NOTE(review): the docstring and the plot above refer to N∥,
        # while the fit uses N_⊥ — confirm which one is intended
        for var in ['X', 'Y', 'N_⊥']:
            k_best = scipy.optimize.minimize(χ2, x0=1, args=(var,)).x[0]
            with self.subTest(var=var):
                self.assertGreater(χ2(k_best, var), 1)
# Command line options, shared by everything in this module.
# The namespace starts out empty; the code in this file reads the
# attributes `update`, `binary`, `keep_failed` and `visual` from it.
# NOTE(review): presumably filled in by the test runner entry point
# before any test is run — confirm against the caller.
options = argparse.Namespace()
def get_basedir(module: str) -> Path:
    """
    Maps a dotted module name onto the matching relative path.

    Every dot-separated component becomes one path segment,
    e.g. tests.03-TCV turns into tests/03-TCV.
    """
    segments = module.split('.')
    return Path(*segments)
def run_gray(inputs: Path, outputs: Path,
             # extra gray parameters
             params: dict[str, Any] = None,
             # which tables to generate
             tables: list[int] = None,
             # which gray binary to use
             binary: str = 'gray',
             # extra options
             options: list[str] = None
             ) -> subprocess.CompletedProcess:
    '''
    Runs gray on the inputs from the `inputs` directory and storing the results
    in the `outputs` directory.

    Parameters:
      inputs:  directory containing the `gray.ini` configuration file
      outputs: directory passed to gray via `-o`
               (created, with its parents, if missing)
      params:  mapping turned into one `-g key=value` pair per item
               (default: no extra parameters)
      tables:  table numbers, joined by commas and passed via `-t`
               (default: 4, 7, 8, 9, 48, 33, 70, 71)
      binary:  name or path of the executable to run
      options: extra command line arguments appended verbatim
               (any sequence of strings; default: none)

    Returns the completed process, with stdout and stderr captured as text.
    A nonzero exit code does not raise: the captured output is printed
    and the caller is expected to check `returncode`.
    '''
    # None stands for the defaults: no mutable objects in the signature
    if params is None:
        params = {}
    if tables is None:
        tables = [4, 7, 8, 9, 48, 33, 70, 71]
    if options is None:
        options = []

    outputs.mkdir(exist_ok=True, parents=True)

    # flatten {k: v, ...} into '-g', 'k=v', ...
    overrides = itertools.chain.from_iterable(
        ('-g', f'{k}={v}') for k, v in params.items())

    args = [
        binary,
        '-c', str(inputs / 'gray.ini'),
        '-t', ','.join(map(str, tables)),
        '-o', str(outputs),
        '-v',
        *overrides,
        *options,
    ]
    proc = subprocess.run(args, capture_output=True, text=True)

    if proc.returncode != 0:
        # show the log on errors
        print(proc.stderr)
        print(proc.stdout)
    return proc
2024-05-15 08:56:34 +02:00
def load_table(fname: Path, header_lines: int = 21) -> np.ndarray:
    '''
    Loads a GRAY output file as a structured numpy array
    (columns are named as in the file header)

    Parameters:
      fname:        path of the file to read
      header_lines: number of leading lines to skip before the line
                    holding the column names (default: 21)

    Returns a structured array with at least one dimension, so a file
    with a single row can be indexed like one with many.

    Errors raised by numpy while opening the file (e.g. FileNotFoundError)
    are propagated to the caller.
    '''
    # ignore warnings about empty files
    # (this silences every warning emitted while parsing)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return np.genfromtxt(fname, names=True,
                             skip_header=header_lines, ndmin=1)