# test_GVEC_class.py
# PyTest treats every `.py` file whose name begins with `test_` (or ends with `_test`) as a test file.
# It then runs every function with the prefix `test_` (optionally inside classes with the prefix `Test`) and evaluates the `assert` statements.
# Therefore it is important to put all imports inside the test function.
# Also, because these test functions are located inside the `test` directory, the import location has to be extended with `../src`.
# This is not good practice, though: the package should be installed in development mode, and PyTest run against that instead.



def test_GVEC_class():
    """
    Test the main interface of this code, the class `GVEC`, and evaluate various MHD variables.

    Loads the `ellipstell` test case shipped in `../GVEC/testcases/` (relative to this file) and runs
    `enumerate_input_types` once for every (variable, form) pair listed below. Each of those calls asserts
    that the different input layouts accepted by `GVEC.get_variable` give matching results.
    """

    # ============================================================
    # Imports.
    # ============================================================

    import os
    import sys
    basedir = os.path.dirname(os.path.realpath(__file__))
    sys.path.insert(0, os.path.join(basedir, '..', 'src')) # Because we are inside './test/' directory.

    # Project imports must come after the `sys.path` update above.
    from gvec_to_python.util.logger import logger
    from gvec_to_python import GVEC, Form, Variable

    logger.info(' ')
    logger.info('='*80)
    logger.info('Running `test_GVEC_class()`.')
    logger.info('Test the `GVEC` class and evaluate various MHD variables.')
    logger.info('='*80)
    logger.info(' ')



    # ============================================================
    # Init GVEC class.
    # ============================================================

    filepath = 'GVEC/testcases/ellipstell/'
    filepath = os.path.join(basedir, '..', filepath) # Because we are inside './test/' directory.
    filename = 'GVEC_ellipStell_profile_update_State_0000_00010000.json'
    gvec = GVEC(filepath, filename)



    # ============================================================
    # Test various MHD variables and p-forms.
    # ============================================================

    # Each group pairs a set of variables with the forms they are tested in.
    # Iteration order is: group -> variable -> form.
    test_groups = (
        (
            (Variable.PRESSURE, Variable.PHI, Variable.CHI, Variable.DPHI, Variable.DCHI, Variable.IOTA),
            (Form.ZERO, Form.THREE),
        ),
        (
            (Variable.A, Variable.B),
            (Form.PHYSICAL, Form.CONTRAVARIANT, Form.ONE, Form.TWO),
        ),
    )

    for variables, forms in test_groups:
        for variable in variables:
            for form in forms:
                enumerate_input_types(gvec, variable=variable, form=form)



def enumerate_input_types(gvec, variable, form, num_s=3, num_u=4, num_v=5):
    """
    Test each of scalar, 1D array, sparse meshgrid, and dense meshgrid inputs.

    Evaluates `gvec.get_variable(s, u, v, variable, form)` on one (s, u, v) grid using four different
    input layouts, and asserts that the scalar, 1D-array and sparse-meshgrid results each match the
    dense-meshgrid result according to `np.allclose`.

    Parameters
    ----------
    gvec : object
        Object providing `get_variable(s, u, v, variable, form)`. The caller in this file passes a `GVEC` instance.
    variable : Variable
        Quantity to evaluate. `Variable.A` and `Variable.B` are stored with a leading axis of length 3;
        every other variable is stored as one value per grid point.
    form : Form
        Passed through unchanged to `get_variable`.
    num_s, num_u, num_v : int
        Grid resolution. Gives `num_s + 1` points in `s` (the point 0 is skipped), and
        `num_u + 1` / `num_v + 1` evenly spaced points on [0, 1] in `u` / `v`.

    Raises
    ------
    AssertionError
        If the result of any input layout differs from the dense meshgrid result.
    """

    # ============================================================
    # Imports.
    # ============================================================

    from gvec_to_python.util.logger import logger

    import numpy as np

    from gvec_to_python import Variable



    # ============================================================
    # Test grid.
    # ============================================================

    s_range = np.linspace(0, 1, num_s+2)[1:] # Skip 0.
    u_range = np.linspace(0, 1, num_u+1)
    v_range = np.linspace(0, 1, num_v+1)
    grid_shape = (s_range.shape[0], u_range.shape[0], v_range.shape[0])
    logger.info('Total number of points: {}'.format(grid_shape[0] * grid_shape[1] * grid_shape[2]))
    logger.info(' ')

    s_sparse, u_sparse, v_sparse = np.meshgrid(s_range, u_range, v_range, indexing='ij', sparse=True)
    s_dense,  u_dense,  v_dense  = np.meshgrid(s_range, u_range, v_range, indexing='ij', sparse=False)

    is_vector_form = variable in (Variable.A, Variable.B)



    # ============================================================
    # Test variable: Scalar inputs.
    # ============================================================

    logger.info('Test {} {}: Scalar inputs: Starts.'.format(variable, form))

    # Vector variables get an extra leading axis holding their 3 components.
    if is_vector_form:
        variable_scalar = np.zeros((3,) + grid_shape)
    else:
        variable_scalar = np.zeros(grid_shape)

    # Evaluate one grid point at a time to build the reference for the point-wise call signature.
    for s_idx, s in enumerate(s_range):
        for u_idx, u in enumerate(u_range):
            for v_idx, v in enumerate(v_range):
                value = gvec.get_variable(s, u, v, variable, form)
                if is_vector_form:
                    variable_scalar[:, s_idx, u_idx, v_idx] = value
                else:
                    variable_scalar[s_idx, u_idx, v_idx] = value
    logger.info('{} {} shape: {}'.format(variable, form, variable_scalar.shape))

    logger.info('Test {} {}: Scalar inputs: Completed.'.format(variable, form))
    logger.info(' ')



    # ============================================================
    # Test variable: 1D arrays before meshgrid, sparse meshgrid, dense meshgrid.
    # ============================================================

    # The three array-valued layouts are handled identically; only the inputs and the log label differ.
    array_inputs = (
        ('1D arrays before meshgrid', (s_range, u_range, v_range)),
        ('Sparse meshgrid', (s_sparse, u_sparse, v_sparse)),
        ('Dense meshgrid', (s_dense, u_dense, v_dense)),
    )

    array_results = []
    for label, (s, u, v) in array_inputs:
        logger.info('Test {} {}: {} inputs: Starts.'.format(variable, form, label))

        result = np.array(gvec.get_variable(s, u, v, variable, form))
        logger.info('{} {} shape: {}'.format(variable, form, result.shape))

        logger.info('Test {} {}: {} inputs: Completed.'.format(variable, form, label))
        logger.info(' ')
        array_results.append(result)

    variable_range, variable_sparse, variable_dense = array_results



    # ============================================================
    # Test variable: Check results from different techniques match.
    # ============================================================

    logger.info('Test {} {}: Compare among inputs: Starts.'.format(variable, form))

    # The dense meshgrid result is the common reference for all comparisons.
    assert np.allclose(variable_scalar, variable_dense), 'Failed to establish equivalence of scalar vs dense meshgrid inputs.'
    assert np.allclose(variable_range , variable_dense), 'Failed to establish equivalence of 1D arr vs dense meshgrid inputs.'
    assert np.allclose(variable_sparse, variable_dense), 'Failed to establish equivalence of sparse vs dense meshgrid inputs.'

    logger.info('Test {} {}: Compare among inputs: Completed.'.format(variable, form))
    logger.info(' ')



# Allow running this test directly as a script, without going through PyTest.
if __name__ == "__main__":

    test_GVEC_class()