Commit 9b6d9324 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 5edadf46
Pipeline #22717 passed with stage
in 4 minutes and 47 seconds
......@@ -59,7 +59,7 @@ class Field(object):
"""
def __init__(self, domain=None, val=None, dtype=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val)
self.domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field):
......@@ -128,7 +128,7 @@ class Field(object):
return Field.empty(field.domain, dtype)
@staticmethod
def _parse_domain(domain, val=None):
def _infer_domain(domain, val=None):
if domain is None:
if isinstance(val, Field):
return val.domain
......
......@@ -71,8 +71,7 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
noise = self.N.diagonal().weight(-1)
mock_noise = Field.from_random(random_type="normal",
domain=self.N.domain,
dtype=noise.dtype.type)
domain=self.N.domain, dtype=noise.dtype)
mock_noise *= sqrt(noise)
mock_data = self.R(mock_signal) + mock_noise
......
......@@ -21,7 +21,6 @@ from builtins import range
from builtins import object
import numpy as np
from ..field import Field, DomainTuple
from .. import utilities
class Prober(object):
......@@ -38,15 +37,16 @@ class Prober(object):
compute_variance=False, ncpu=1):
self._domain = DomainTuple.make(domain)
self._probe_count = self._parse_probe_count(probe_count)
self._ncpu = self._parse_probe_count(ncpu)
self._random_type = self._parse_random_type(random_type)
self._probe_count = int(probe_count)
self._ncpu = int(ncpu)
if random_type not in ["pm1", "normal"]:
raise ValueError(
"unsupported random type: '" + str(random_type) + "'.")
self._random_type = random_type
self.compute_variance = bool(compute_variance)
self.probe_dtype = np.dtype(probe_dtype)
self._uid_counter = 0
# ---Properties---
@property
def domain(self):
return self._domain
......@@ -55,22 +55,11 @@ class Prober(object):
def probe_count(self):
return self._probe_count
def _parse_probe_count(self, probe_count):
return int(probe_count)
@property
def random_type(self):
return self._random_type
def _parse_random_type(self, random_type):
if random_type not in ["pm1", "normal"]:
raise ValueError(
"unsupported random type: '" + str(random_type) + "'.")
return random_type
# ---Probing methods---
def gen_parallel_probe(self,callee):
def gen_parallel_probe(self, callee):
for i in range(self.probe_count):
yield (callee, self.get_probe(i))
......@@ -87,7 +76,7 @@ class Prober(object):
pool = Pool(self._ncpu)
for i in pool.imap_unordered(self.evaluate_probe_parallel,
self.gen_parallel_probe(callee)):
self.finish_probe(i[0],i[1])
self.finish_probe(i[0], i[1])
def evaluate_probe_parallel(self, argtuple):
callee = argtuple[0]
......@@ -104,8 +93,7 @@ class Prober(object):
def generate_probe(self):
""" a random-probe generator """
f = Field.from_random(random_type=self.random_type,
domain=self.domain,
dtype=self.probe_dtype.type)
domain=self.domain, dtype=self.probe_dtype)
uid = self._uid_counter
self._uid_counter += 1
return (uid, f)
......
......@@ -53,7 +53,19 @@ class RGSpace(Space):
if np.isscalar(shape):
shape = (shape,)
self._shape = tuple(int(i) for i in shape)
self._distances = self._parse_distances(distances)
if distances is None:
if self.harmonic:
self._distances = (1.,) * len(self._shape)
else:
self._distances = tuple(1./s for s in self._shape)
elif np.isscalar(distances):
self._distances = (float(distances),) * len(self._shape)
else:
temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances
self._distances = tuple(temp)
self._dvol = float(reduce(lambda x, y: x*y, self._distances))
self._dim = int(reduce(lambda x, y: x*y, self._shape))
......@@ -169,14 +181,3 @@ class RGSpace(Space):
distance between neighboring grid points along the n-th dimension.
"""
return self._distances
def _parse_distances(self, distances):
if distances is None:
if self.harmonic:
temp = np.ones_like(self.shape, dtype=np.float64)
else:
temp = 1./np.array(self.shape, dtype=np.float64)
else:
temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances
return tuple(temp)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment