Commit df8cb230 authored by Ultima's avatar Ultima
Browse files

field can now be initialized from another field.

parent ef955b69
......@@ -1922,7 +1922,7 @@ class field(object):
"""
def __init__(self, domain=None, val=None, codomain=None, ishape=None,
**kwargs):
copy=False, **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -1947,10 +1947,58 @@ class field(object):
"""
# If the given val was a field, try to cast it accordingly to the given
# domain and codomain, etc...
if isinstance(val, field):
self._init_from_field(f=val,
domain=domain,
codomain=codomain,
ishape=ishape,
copy=copy,
**kwargs)
else:
self._init_from_array(val=val,
domain=domain,
codomain=codomain,
ishape=ishape,
copy=copy,
**kwargs)
def _init_from_field(self, f, domain, codomain, ishape, copy, **kwargs):
# check domain
if domain is None:
domain = f.domain
# check codomain
if codomain is None:
if domain.check_codomain(f.codomain):
codomain = f.codomain
else:
codomain = domain.get_codomain()
# Check if the given field lives in the same fourier-space as the
# new domain
if f.domain.harmonic != domain.harmonic:
# Try to transform the given field to the given domain/codomain
f = f.transform(new_domain=domain,
new_codomain=codomain)
# Check if the domain is now really the same.
# This is necessary since iso-fourier-conversion is not implemented
if f.domain == domain:
self._init_from_array(domain=domain,
val=f.val,
codomain=codomain,
ishape=ishape,
copy=copy,
**kwargs)
else:
raise ValueError(about._errors.cstring(
"ERROR: Incompatible domain given."))
def _init_from_array(self, val, domain, codomain, ishape, copy, **kwargs):
# check domain
if not isinstance(domain, space):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space."))
self.domain = domain
# check codomain
......@@ -1985,7 +2033,7 @@ class field(object):
val = self._map(lambda: self.domain.get_random_values(
codomain=self.codomain,
**kwargs))
self.set_val(new_val=val)
self.set_val(new_val=val, copy=copy)
def __len__(self):
return int(self.get_dim(split=True)[0])
......@@ -2444,7 +2492,8 @@ class field(object):
return work_field
def transform(self, codomain=None, overwrite=False, **kwargs):
def transform(self, new_domain=None, new_codomain=None, overwrite=False,
**kwargs):
"""
Computes the transform of the field using the appropriate conjugate
transformation.
......@@ -2469,24 +2518,31 @@ class field(object):
Otherwise, nothing is returned.
"""
if codomain is None:
codomain = self.codomain
if new_domain is None:
new_domain = self.codomain
if new_codomain is None:
# try to recycle the old domain
if new_domain.check_codomain(self.domain):
new_codomain = self.domain
else:
new_codomain = new_domain.get_codomain()
else:
assert(self.domain.check_codomain(codomain))
assert(new_domain.check_codomain(new_codomain))
new_val = self._map(
lambda z: self.domain.calc_transform(
z, codomain=codomain, **kwargs),
z, codomain=new_domain, **kwargs),
self.get_val())
if overwrite:
return_field = self
return_field.set_codomain(new_codomain=self.domain, force=True)
return_field.set_domain(new_domain=codomain, force=True)
return_field.set_codomain(new_codomain=new_codomain, force=True)
return_field.set_domain(new_domain=new_domain, force=True)
else:
return_field = self.copy_empty(domain=self.codomain,
codomain=self.domain)
return_field.set_val(new_val=new_val)
return_field = self.copy_empty(domain=new_domain,
codomain=new_codomain)
return_field.set_val(new_val=new_val, copy=False)
return return_field
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment