Commit f46fca01 authored by Jait Dixit's avatar Jait Dixit

Merge branch 'master' into tests

parents fd9d341d 185f454e
Pipeline #10084 passed with stage
in 19 minutes and 36 seconds
# -*- coding: utf-8 -*-
from nifty import Field, RGSpace, DiagonalProberMixin, TraceProberMixin,\
Prober, DiagonalOperator
class DiagonalProber(DiagonalProberMixin, Prober):
pass
class MultiProber(DiagonalProberMixin, TraceProberMixin, Prober):
pass
x = RGSpace((8, 8))
f = Field.from_random(domain=x, random_type='normal')
diagOp = DiagonalOperator(domain=x, diagonal=f)
diagProber = DiagonalProber(domain=x)
diagProber(diagOp)
print (f - diagProber.diagonal).norm()
multiProber = MultiProber(domain=x)
multiProber(diagOp)
print (f - multiProber.diagonal).norm()
print f.sum() - multiProber.trace
from nifty import * from nifty import *
#import plotly.offline as pl import plotly.offline as pl
#import plotly.graph_objs as go import plotly.graph_objs as go
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
rank = comm.rank rank = comm.rank
np.random.seed(42)
class WienerFilterEnergy(Energy): class WienerFilterEnergy(Energy):
def __init__(self, position, D, j): def __init__(self, position, D, j):
...@@ -34,6 +35,17 @@ class WienerFilterEnergy(Energy): ...@@ -34,6 +35,17 @@ class WienerFilterEnergy(Energy):
return_g.val = g.val.real return_g.val = g.val.real
return return_g return return_g
@property
def curvature(self):
class Dummy(object):
def __init__(self, x):
self.x = x
def inverse_times(self, *args, **kwargs):
return self.x.times(*args, **kwargs)
my_dummy = Dummy(self.D)
return my_dummy
@memo @memo
def D_inverse_x(self): def D_inverse_x(self):
return D.inverse_times(self.position) return D.inverse_times(self.position)
...@@ -82,14 +94,18 @@ if __name__ == "__main__": ...@@ -82,14 +94,18 @@ if __name__ == "__main__":
x = energy.position x = energy.position
print (iteration, ((x-ss).norm()/ss.norm()).real) print (iteration, ((x-ss).norm()/ss.norm()).real)
minimizer = SteepestDescent(convergence_tolerance=0, # minimizer = SteepestDescent(convergence_tolerance=0,
iteration_limit=50, # iteration_limit=50,
callback=distance_measure) # callback=distance_measure)
minimizer = RelaxedNewton(convergence_tolerance=0,
iteration_limit=2,
callback=distance_measure)
minimizer = VL_BFGS(convergence_tolerance=0, # minimizer = VL_BFGS(convergence_tolerance=0,
iteration_limit=50, # iteration_limit=50,
callback=distance_measure, # callback=distance_measure,
max_history_length=3) # max_history_length=3)
m0 = Field(s_space, val=1) m0 = Field(s_space, val=1)
...@@ -97,40 +113,35 @@ if __name__ == "__main__": ...@@ -97,40 +113,35 @@ if __name__ == "__main__":
(energy, convergence) = minimizer(energy) (energy, convergence) = minimizer(energy)
m = energy.position
d_data = d.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
sh_data = sh.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
j_data = j.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=j_data)], filename='j.html')
jabs_data = np.abs(j.val.get_full_data())
jphase_data = np.angle(j.val.get_full_data())
if rank == 0:
pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
m_data = m.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
#
#
# grad = gradient(m)
#
# d_data = d.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
#
#
# ss_data = ss.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
#
# sh_data = sh.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
#
# j_data = j.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=j_data)], filename='j.html')
#
# jabs_data = np.abs(j.val.get_full_data())
# jphase_data = np.angle(j.val.get_full_data())
# if rank == 0:
# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
# pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
#
# m_data = m.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
# grad_data = grad.val.get_full_data().real # grad_data = grad.val.get_full_data().real
# if rank == 0: # if rank == 0:
# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html') # pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
...@@ -55,7 +55,7 @@ from spaces import * ...@@ -55,7 +55,7 @@ from spaces import *
from operators import * from operators import *
#from probing import * from probing import *
from sugar import * from sugar import *
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from keepers import Loggable
class Energy(object):
class Energy(Loggable, object):
def __init__(self, position): def __init__(self, position):
self._cache = {} self._cache = {}
try: try:
......
...@@ -34,7 +34,10 @@ class Field(Loggable, Versionable, object): ...@@ -34,7 +34,10 @@ class Field(Loggable, Versionable, object):
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
val=val) val=val)
self.set_val(new_val=val, copy=copy) if val is None:
self._val = None
else:
self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain, val=None): def _parse_domain(self, domain, val=None):
if domain is None: if domain is None:
...@@ -406,6 +409,9 @@ class Field(Loggable, Versionable, object): ...@@ -406,6 +409,9 @@ class Field(Loggable, Versionable, object):
return self return self
def get_val(self, copy=False): def get_val(self, copy=False):
if self._val is None:
self.set_val(None)
if copy: if copy:
return self._val.copy() return self._val.copy()
else: else:
...@@ -413,11 +419,11 @@ class Field(Loggable, Versionable, object): ...@@ -413,11 +419,11 @@ class Field(Loggable, Versionable, object):
@property @property
def val(self): def val(self):
return self._val return self.get_val(copy=False)
@val.setter @val.setter
def val(self, new_val): def val(self, new_val):
self._val = self.cast(new_val) self.set_val(new_val=new_val, copy=False)
@property @property
def shape(self): def shape(self):
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
import pickle import pickle
import numpy as np
from field_type import FieldType from field_type import FieldType
class FieldArray(FieldType): class FieldArray(FieldType):
def __init__(self, dtype, shape): def __init__(self, shape, dtype=np.float):
try: try:
new_shape = tuple([int(i) for i in shape]) new_shape = tuple([int(i) for i in shape])
except TypeError: except TypeError:
......
...@@ -5,3 +5,4 @@ from conjugate_gradient import ConjugateGradient ...@@ -5,3 +5,4 @@ from conjugate_gradient import ConjugateGradient
from quasi_newton_minimizer import QuasiNewtonMinimizer from quasi_newton_minimizer import QuasiNewtonMinimizer
from steepest_descent import SteepestDescent from steepest_descent import SteepestDescent
from vl_bfgs import VL_BFGS from vl_bfgs import VL_BFGS
from relaxed_newton import RelaxedNewton
...@@ -21,8 +21,9 @@ class RelaxedNewton(QuasiNewtonMinimizer): ...@@ -21,8 +21,9 @@ class RelaxedNewton(QuasiNewtonMinimizer):
gradient = energy.gradient gradient = energy.gradient
curvature = energy.curvature curvature = energy.curvature
descend_direction = curvature.inverse_times(gradient) descend_direction = curvature.inverse_times(gradient)
norm = descend_direction.norm() return descend_direction * -1
if norm != 1: #norm = descend_direction.norm()
return descend_direction / -norm # if norm != 1:
else: # return descend_direction / -norm
return descend_direction * -1 # else:
# return descend_direction * -1
...@@ -69,7 +69,7 @@ class FFTOperator(LinearOperator): ...@@ -69,7 +69,7 @@ class FFTOperator(LinearOperator):
self._backward_transformation = TransformationCache.create( self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module) backward_class, self.target[0], self.domain[0], module=module)
def _times(self, x, spaces): def _times(self, x, spaces, dtype=None):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None: if spaces is None:
# this case means that x lives on only one space, which is # this case means that x lives on only one space, which is
...@@ -87,12 +87,12 @@ class FFTOperator(LinearOperator): ...@@ -87,12 +87,12 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain) result_domain = list(x.domain)
result_domain[spaces[0]] = self.target[0] result_domain[spaces[0]] = self.target[0]
result_field = x.copy_empty(domain=result_domain) result_field = x.copy_empty(domain=result_domain, dtype=dtype)
result_field.set_val(new_val=new_val, copy=False) result_field.set_val(new_val=new_val, copy=False)
return result_field return result_field
def _inverse_times(self, x, spaces): def _inverse_times(self, x, spaces, dtype=None):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None: if spaces is None:
# this case means that x lives on only one space, which is # this case means that x lives on only one space, which is
...@@ -110,7 +110,7 @@ class FFTOperator(LinearOperator): ...@@ -110,7 +110,7 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain) result_domain = list(x.domain)
result_domain[spaces[0]] = self.domain[0] result_domain[spaces[0]] = self.domain[0]
result_field = x.copy_empty(domain=result_domain) result_field = x.copy_empty(domain=result_domain, dtype=dtype)
result_field.set_val(new_val=new_val, copy=False) result_field.set_val(new_val=new_val, copy=False)
return result_field return result_field
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from mixin_base import MixinBase
from diagonal_prober_mixin import DiagonalProberMixin from diagonal_prober_mixin import DiagonalProberMixin
from trace_prober_mixin import TraceProberMixin from trace_prober_mixin import TraceProberMixin
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from mixin_base import MixinBase
class DiagonalProberMixin(object):
class DiagonalProberMixin(MixinBase): def __init__(self, *args, **kwargs):
def __init__(self):
self.reset() self.reset()
super(DiagonalProberMixin, self).__init__() super(DiagonalProberMixin, self).__init__(*args, **kwargs)
def reset(self): def reset(self):
self.__sum_of_probings = 0 self.__sum_of_probings = 0
......
# -*- coding: utf-8 -*-
class MixinBase(object):
def reset(self, *args, **kwargs):
pass
def finish_probe(self, *args, **kwargs):
pass
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from mixin_base import MixinBase
class TraceProberMixin(object):
class TraceProberMixin(MixinBase): def __init__(self, *args, **kwargs):
def __init__(self):
self.reset() self.reset()
super(TraceProberMixin, self).__init__() super(TraceProberMixin, self).__init__(*args, **kwargs)
def reset(self): def reset(self):
self.__sum_of_probings = 0 self.__sum_of_probings = 0
......
...@@ -34,8 +34,6 @@ class Prober(object): ...@@ -34,8 +34,6 @@ class Prober(object):
self._random_type = self._parse_random_type(random_type) self._random_type = self._parse_random_type(random_type)
self.compute_variance = bool(compute_variance) self.compute_variance = bool(compute_variance)
super(Prober, self).__init__()
# ---Properties--- # ---Properties---
@property @property
...@@ -84,7 +82,7 @@ class Prober(object): ...@@ -84,7 +82,7 @@ class Prober(object):
self.finish_probe(current_probe, pre_result) self.finish_probe(current_probe, pre_result)
def reset(self): def reset(self):
super(Prober, self).reset() pass
def get_probe(self, index): def get_probe(self, index):
""" layer of abstraction for potential probe-caching """ """ layer of abstraction for potential probe-caching """
...@@ -107,7 +105,7 @@ class Prober(object): ...@@ -107,7 +105,7 @@ class Prober(object):
return callee(probe, **kwargs) return callee(probe, **kwargs)
def finish_probe(self, probe, pre_result): def finish_probe(self, probe, pre_result):
super(Prober, self).finish_probe(probe, pre_result) pass
def __call__(self, callee): def __call__(self, callee):
return self.probing_run(callee) return self.probing_run(callee)
...@@ -139,8 +139,8 @@ class GLSpace(Space): ...@@ -139,8 +139,8 @@ class GLSpace(Space):
if axes is not None: if axes is not None:
# reshape the weight array to match the input shape # reshape the weight array to match the input shape
new_shape = np.ones(len(x.shape), dtype=np.int) new_shape = np.ones(len(x.shape), dtype=np.int)
for index in range(len(axes)): # we know len(axes) is always 1
new_shape[index] = len(weight) new_shape[axes[0]] = len(weight)
weight = weight.reshape(new_shape) weight = weight.reshape(new_shape)
if inplace: if inplace:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment