Skip to content
Snippets Groups Projects
Commit d82f28b2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

demo implementation

parent ad4c6cde
Branches
Tags
1 merge request!426Be more paranoid about initializing the RNG
Pipeline #70986 failed
...@@ -46,7 +46,7 @@ def make_random_mask(): ...@@ -46,7 +46,7 @@ def make_random_mask():
if __name__ == '__main__': if __name__ == '__main__':
np.random.seed(42) ift.random.init(42)
# Choose space on which the signal field is defined # Choose space on which the signal field is defined
if len(sys.argv) == 2: if len(sys.argv) == 2:
......
from .version import __version__ from .version import __version__
from . import random
from .domains.domain import Domain from .domains.domain import Domain
from .domains.structured_domain import StructuredDomain from .domains.structured_domain import StructuredDomain
from .domains.unstructured_domain import UnstructuredDomain from .domains.unstructured_domain import UnstructuredDomain
......
...@@ -140,9 +140,9 @@ class Field(object): ...@@ -140,9 +140,9 @@ class Field(object):
Field Field
The newly created Field. The newly created Field.
""" """
from .random import Random from . import random
domain = DomainTuple.make(domain) domain = DomainTuple.make(domain)
generator_function = getattr(Random, random_type) generator_function = getattr(random, random_type)
arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs) arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs)
return Field(domain, arr) return Field(domain, arr)
......
...@@ -11,59 +11,74 @@ ...@@ -11,59 +11,74 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np import numpy as np
_initialized = False
class Random(object): def init(seed):
@staticmethod global _initialized
def pm1(dtype, shape): if _initialized:
if np.issubdtype(dtype, np.complexfloating): print("WARNING: re-intializing random generator")
x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype) np.random.seed(seed)
x = x[np.random.randint(4, size=shape)] else:
else: _initialized = True
x = 2*np.random.randint(2, size=shape) - 1 np.random.seed(seed)
return x.astype(dtype, copy=False)
@staticmethod def pm1(dtype, shape):
def normal(dtype, shape, mean=0., std=1.): global _initialized
if not (np.issubdtype(dtype, np.floating) or if not _initialized:
np.issubdtype(dtype, np.complexfloating)): raise RuntimeError("RNG not initialized")
raise TypeError("dtype must be float or complex") if np.issubdtype(dtype, np.complexfloating):
if not np.isscalar(mean) or not np.isscalar(std): x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype)
raise TypeError("mean and std must be scalars") x = x[np.random.randint(4, size=shape)]
if np.issubdtype(type(std), np.complexfloating): else:
raise TypeError("std must not be complex") x = 2*np.random.randint(2, size=shape) - 1
if ((not np.issubdtype(dtype, np.complexfloating)) and return x.astype(dtype, copy=False)
np.issubdtype(type(mean), np.complexfloating)):
raise TypeError("mean must not be complex for a real result field")
if np.issubdtype(dtype, np.complexfloating):
x = np.empty(shape, dtype=dtype)
x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape)
x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape)
else:
x = np.random.normal(mean, std, shape).astype(dtype, copy=False)
return x
@staticmethod def normal(dtype, shape, mean=0., std=1.):
def uniform(dtype, shape, low=0., high=1.): global _initialized
if not np.isscalar(low) or not np.isscalar(high): if not _initialized:
raise TypeError("low and high must be scalars") raise RuntimeError("RNG not initialized")
if (np.issubdtype(type(low), np.complexfloating) or if not (np.issubdtype(dtype, np.floating) or
np.issubdtype(type(high), np.complexfloating)): np.issubdtype(dtype, np.complexfloating)):
raise TypeError("low and high must not be complex") raise TypeError("dtype must be float or complex")
if np.issubdtype(dtype, np.complexfloating): if not np.isscalar(mean) or not np.isscalar(std):
x = np.empty(shape, dtype=dtype) raise TypeError("mean and std must be scalars")
x.real = np.random.uniform(low, high, shape) if np.issubdtype(type(std), np.complexfloating):
x.imag = np.random.uniform(low, high, shape) raise TypeError("std must not be complex")
elif np.issubdtype(dtype, np.integer): if ((not np.issubdtype(dtype, np.complexfloating)) and
if not (np.issubdtype(type(low), np.integer) and np.issubdtype(type(mean), np.complexfloating)):
np.issubdtype(type(high), np.integer)): raise TypeError("mean must not be complex for a real result field")
raise TypeError("low and high must be integer") if np.issubdtype(dtype, np.complexfloating):
x = np.random.randint(low, high+1, shape) x = np.empty(shape, dtype=dtype)
else: x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape)
x = np.random.uniform(low, high, shape) x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape)
return x.astype(dtype, copy=False) else:
x = np.random.normal(mean, std, shape).astype(dtype, copy=False)
return x
def uniform(dtype, shape, low=0., high=1.):
global _initialized
if not _initialized:
raise RuntimeError("RNG not initialized")
if not np.isscalar(low) or not np.isscalar(high):
raise TypeError("low and high must be scalars")
if (np.issubdtype(type(low), np.complexfloating) or
np.issubdtype(type(high), np.complexfloating)):
raise TypeError("low and high must not be complex")
if np.issubdtype(dtype, np.complexfloating):
x = np.empty(shape, dtype=dtype)
x.real = np.random.uniform(low, high, shape)
x.imag = np.random.uniform(low, high, shape)
elif np.issubdtype(dtype, np.integer):
if not (np.issubdtype(type(low), np.integer) and
np.issubdtype(type(high), np.integer)):
raise TypeError("low and high must be integer")
x = np.random.randint(low, high+1, shape)
else:
x = np.random.uniform(low, high, shape)
return x.astype(dtype, copy=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment