gray/tests/__init__.py

270 lines
9.0 KiB
Python
Raw Normal View History

2023-12-18 00:52:11 +01:00
import argparse
import itertools
import shutil
import subprocess
import tempfile
import unittest
from pathlib import Path
from typing import Any, Optional
from unittest import TestCase

import numpy as np
class GrayTest:
inputs: Path = None # directory of the input files
reference: Path = None # directory of the reference outputs
candidate: Path = None # directory of the candidate outputs
# Extra parameters to pass to gray
gray_params: dict[str, Any] = {}
@classmethod
def setUpClass(cls):
'''
Sets up the test case
'''
# directory of the test case
base = Path().joinpath(*cls.__module__.split('.'))
if cls.inputs is None:
cls.inputs = base / 'inputs'
if cls.reference is None:
cls.reference = base / 'outputs'
# temporary directory holding the candidate outputs
cls._tempdir = tempfile.mkdtemp(prefix=f'gray-test-{base.name}.')
cls.candidate = Path(cls._tempdir)
# replace reference with candidate
if options.update:
print()
print('Setting new reference for ' + cls.__module__)
cls.candidate = cls.reference
# run gray to generate the candidate outputs
proc = run_gray(cls.inputs, cls.candidate, params=cls.gray_params,
binary=options.binary)
# 0: all good, 1: input errors, >1: simulation errors
assert proc.returncode != 1, 'gray failed with exit code 1'
2023-12-18 00:52:11 +01:00
# store the stderr for manual inspection
with open(str(cls.candidate / 'log'), 'w') as log:
log.write(proc.stderr)
@classmethod
def tearDownClass(cls):
'''
Clean up after all tests
'''
# remove temporary directory
if cls._passed or not options.keep_failed:
shutil.rmtree(cls._tempdir)
else:
print()
print('Some tests failed: preserving outputs in', cls._tempdir)
def run(self, result: unittest.runner.TextTestResult):
'''
Override to store the test results for tearDownClass
'''
TestCase.run(self, result)
self.__class__._passed = result.failures == []
def test_eccd_values(self):
'''
Comparing the ECCD values
'''
from collections import defaultdict
2024-05-15 08:56:34 +02:00
try:
ref = load_table(self.reference / 'summary.7.txt')
cand = load_table(self.candidate / 'summary.7.txt')
except FileNotFoundError:
2024-01-30 12:34:45 +01:00
raise unittest.SkipTest("ECCD results not available")
2023-12-18 00:52:11 +01:00
# precision as number of decimal places
prec = defaultdict(lambda: 3, [
2024-05-15 08:56:34 +02:00
('dPdV_peak', -2), ('dPdV_max', -2),
('J_φ_peak', -2), ('J_φ_max', -2),
('s_max', -1), ('χ', -1), ('ψ', -1),
2023-12-18 00:52:11 +01:00
])
for val in ref.dtype.names:
with self.subTest(value=val):
for i, ray in enumerate(ref['index_rt']):
self.assertAlmostEqual(
ref[val][i], cand[val][i], prec[val],
msg=f"{val} changed (ray {int(ray)})")
def test_final_position(self):
'''
Comparing the final position of the central ray
'''
2024-05-15 08:56:34 +02:00
ref = load_table(self.reference / 'central-ray.4.txt')
cand = load_table(self.candidate / 'central-ray.4.txt')
2023-12-18 00:52:11 +01:00
# coordinates
self.assertAlmostEqual(ref['R'][-1], cand['R'][-1], 1)
self.assertAlmostEqual(ref['z'][-1], cand['z'][-1], 1)
2024-05-15 08:56:34 +02:00
self.assertAlmostEqual(ref['φ'][-1], cand['φ'][-1], 2)
2023-12-18 00:52:11 +01:00
# optical path length
2024-05-15 08:56:34 +02:00
self.assertAlmostEqual(ref['s'][-1], cand['s'][-1], 1)
2023-12-18 00:52:11 +01:00
def test_final_direction(self):
'''
Comparing the final direction of the central ray
'''
2024-05-15 08:56:34 +02:00
ref = load_table(self.reference / 'central-ray.4.txt')
cand = load_table(self.candidate / 'central-ray.4.txt')
2023-12-18 00:52:11 +01:00
2024-05-15 08:56:34 +02:00
self.assertAlmostEqual(ref['N_⊥'][-1], cand['N_⊥'][-1], 1)
self.assertAlmostEqual(ref['N_∥'][-1], cand['N_∥'][-1], 1)
2023-12-18 00:52:11 +01:00
def test_beam_shape(self):
'''
Comparing the final beam shape
'''
2024-05-15 08:56:34 +02:00
try:
ref = load_table(self.reference / 'beam-shape-final.9.txt')
cand = load_table(self.candidate / 'beam-shape-final.9.txt')
except FileNotFoundError:
2023-12-18 00:52:11 +01:00
raise unittest.SkipTest("Beam shape info not available")
if options.visual:
import matplotlib.pyplot as plt
plt.subplot(aspect='equal')
plt.title(self.__module__ + '.test_beam_shape')
plt.xlabel('$x$ / cm')
plt.ylabel('$y$ / cm')
2024-05-15 08:56:34 +02:00
plt.scatter(ref['x'], ref['y'], c='red',
2023-12-18 00:52:11 +01:00
marker='_', label='reference')
2024-05-15 08:56:34 +02:00
plt.scatter(cand['x'], cand['y'], c='green',
2023-12-18 00:52:11 +01:00
alpha=0.6, marker='+', label='candidate')
plt.legend()
plt.show()
for ref, cand in zip(ref, cand):
with self.subTest(ray=(int(ref['j']), int(ref['k']))):
2024-05-15 08:56:34 +02:00
self.assertAlmostEqual(ref['x'], cand['x'], 1)
self.assertAlmostEqual(ref['y'], cand['y'], 1)
2023-12-18 00:52:11 +01:00
def test_error_biased(self):
'''
Test for a proportionality between Λ and any of X, Y, N∥
'''
2024-05-15 08:56:34 +02:00
data = load_table(self.candidate / 'central-ray.4.txt')
2023-12-18 00:52:11 +01:00
# restrict to within the plasma, half of the first pass
2024-05-15 08:56:34 +02:00
in_plasma = data['X'] > 0
2023-12-18 00:52:11 +01:00
first_pass = data['index_rt'] == data['index_rt'].min()
data = data[in_plasma & first_pass]
data = data[:int(data.size // 2)]
if data.size < 2:
self.skipTest("There is no plasma")
if options.visual:
import matplotlib.pyplot as plt
left = plt.subplot()
plt.title(self.__module__ + '.test_error_biased')
left.set_xlabel('$s$ / cm')
left.set_ylabel('$Λ$', color='xkcd:ocean blue')
left.tick_params(axis='y', labelcolor='xkcd:ocean blue')
2024-05-15 08:56:34 +02:00
left.plot(data['s'], data['Λ_r'], color='xkcd:ocean blue')
2023-12-18 00:52:11 +01:00
right1 = left.twinx()
right1.set_ylabel('$X$', color='xkcd:orange')
right1.tick_params(axis='y', labelcolor='xkcd:orange')
2024-05-15 08:56:34 +02:00
right1.plot(data['s'], data['X'], color='xkcd:orange')
2023-12-18 00:52:11 +01:00
right2 = left.twinx()
right2.set_ylabel('$Y$', color='xkcd:vermillion')
right2.tick_params(axis='y', labelcolor='xkcd:vermillion')
2024-05-15 08:56:34 +02:00
right2.plot(data['s'], data['Y'], color='xkcd:vermillion')
2023-12-18 00:52:11 +01:00
right2.spines["right"].set_position(("axes", 1.1))
right3 = left.twinx()
right3.set_ylabel('$N_∥$', color='xkcd:green')
right3.tick_params(axis='y', labelcolor='xkcd:green')
right3.spines["right"].set_position(("axes", 1.2))
2024-05-15 08:56:34 +02:00
right3.plot(data['s'], data['N_∥'], color='xkcd:green')
2023-12-18 00:52:11 +01:00
plt.subplots_adjust(right=0.78)
plt.show()
2024-05-15 08:56:34 +02:00
err = data['Λ_r'].var() / 10
2023-12-18 00:52:11 +01:00
self.assertGreater(err, 0, msg="Λ is exactly constant")
2024-05-15 08:56:34 +02:00
for var in ['X', 'Y', 'N_⊥']:
# Minimise the χ²(k) = |(Λ_r - k⋅var) / err|² / (n - 1)
# The solution is simply: k = (Λ⋅var)/var⋅var
k = np.dot(data['Λ_r'], data[var]) / np.linalg.norm(data[var])**2
2023-12-18 00:52:11 +01:00
with self.subTest(var=var):
res = (data['Λ_r'] - k*data[var]) / err
χ2 = np.linalg.norm(res)**2 / (data.size - 1)
self.assertGreater(χ2, 1)
2023-12-18 00:52:11 +01:00
# Command line options of the test suite, shared by every test case.
# Presumably populated by the test entry point (not visible here); GrayTest
# reads `update`, `binary`, `keep_failed` and `visual` from it.
options: argparse.Namespace = argparse.Namespace()
def get_basedir(module: str) -> Path:
    """
    Return the directory of a test case given its module name.

    Each dot-separated component of the module name becomes one component of
    a relative path, e.g. the module `tests.03-TCV` maps to `tests/03-TCV`.
    """
    parts = module.split('.')
    return Path(*parts)
def run_gray(inputs: Path, outputs: Path,
             # extra gray parameters
             params: Optional[dict[str, Any]] = None,
             # which tables to generate
             tables: Optional[list[int]] = None,
             # which gray binary to use
             binary: str = 'gray',
             # extra options
             options: Optional[list[str]] = None
             ) -> subprocess.CompletedProcess:
    '''
    Runs gray on the inputs from the `inputs` directory and stores the results
    in the `outputs` directory (created, with its parents, if missing).

    Parameters:
      inputs:  directory containing the `gray.ini` configuration file
      outputs: directory passed to gray with `-o`
      params:  extra parameters, each passed as `-g key=value`
      tables:  which tables to generate (default: 4, 7, 8, 9, 48, 33, 70, 71)
      binary:  which gray executable to run
      options: extra command-line arguments, appended verbatim (any iterable
               of strings; unrelated to the module-level `options` namespace)

    Returns the completed process, with stdout and stderr captured as text.
    On a non-zero exit status, the command line and both streams are printed.
    '''
    # None instead of mutable defaults, which would be shared between calls
    if params is None:
        params = {}
    if tables is None:
        tables = (4, 7, 8, 9, 48, 33, 70, 71)
    if options is None:
        options = ()

    outputs.mkdir(exist_ok=True, parents=True)
    args = [
        binary,
        '-c', str(inputs / 'gray.ini'),
        '-t', ','.join(map(str, tables)),
        '-o', str(outputs),
        '-v',
    ]
    for key, value in params.items():
        args += ['-g', f'{key}={value}']
    args.extend(options)

    proc = subprocess.run(args, capture_output=True, text=True)
    print()
    if proc.returncode != 0:
        # show the log on errors
        print(f'Errors occurred (exit status {proc.returncode}), showing log:')
        print(*proc.args)
        print(proc.stderr)
        print(proc.stdout)
    return proc
def load_table(fname: Path, skip_header: int = 21,
               encoding: str = 'utf-8') -> np.ndarray:
    '''
    Loads a GRAY output file as a structured numpy array
    (columns are named as in the file header)

    Parameters:
      fname:       path of the table
      skip_header: number of lines preceding the line with the column names
      encoding:    text encoding of the file; it is set explicitly because the
                   column names contain non-ASCII characters (e.g. `φ`, `N_⊥`)
                   and numpy's default decoding depends on version and locale

    The result is always at least one-dimensional, so a table with a single
    row can still be indexed row by row.

    Raises FileNotFoundError if `fname` does not exist.
    '''
    return np.genfromtxt(fname, names=True, skip_header=skip_header,
                         ndmin=1, encoding=encoding)