From d82637677b9f1a746e42f3e11e7af0141f22785e Mon Sep 17 00:00:00 2001
From: Pierre Navaro <pierre.navaro@math.cnrs.fr>
Date: Tue, 30 Oct 2018 13:34:27 +0100
Subject: [PATCH] use tensoropt and increase discretization

---
 src/interpolation.jl         |  6 +++---
 test/runtests.jl             | 13 -------------
 test/test_interpolation1d.jl |  2 +-
 test/test_interpolation2d.jl |  2 +-
 test/test_interpolation3d.jl | 10 +++++-----
 test/test_interpolation4d.jl | 12 ++++++------
 test/test_interpolation5d.jl | 14 +++++++-------
 7 files changed, 23 insertions(+), 36 deletions(-)

diff --git a/src/interpolation.jl b/src/interpolation.jl
index 7c37150..72b6864 100644
--- a/src/interpolation.jl
+++ b/src/interpolation.jl
@@ -95,7 +95,7 @@ function interpolate(interp_x :: InterpolationType,
     y = interp_y(ye) 
     z = interp_z(ze) 
 
-    @tensor begin
+    @tensoropt begin
         s[e1,e2,e3] := x[e1,c1]*y[e2,c2]*z[e3,c3]*f[c1,c2,c3]
     end
     s
@@ -139,7 +139,7 @@ function interpolate(interp_x :: InterpolationType,
     z = interp_z(ze) 
     v = interp_v(ve) 
 
-    @tensor begin
+    @tensoropt begin
         s[e1,e2,e3,e4] := x[e1,c1]*y[e2,c2]*z[e3,c3]*v[e4,c4]*f[c1,c2,c3,c4]
     end
     s
@@ -190,7 +190,7 @@ function interpolate(interp_x :: InterpolationType,
     v = interp_v(ve) 
     w = interp_w(we) 
 
-    @tensor begin
+    @tensoropt begin
         s[e1,e2,e3,e4,e5] := x[e1,c1]*y[e2,c2]*z[e3,c3]*v[e4,c4]*w[e5,c5]*f[c1,c2,c3,c4,c5]
     end
     s
diff --git a/test/runtests.jl b/test/runtests.jl
index 4c16cf3..03fc0a9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,4 @@
 using HermiteGF
-
 using Test
 
 include("trapz.jl")
@@ -8,15 +7,3 @@ include("test_interpolation2d.jl")
 include("test_interpolation3d.jl")
 include("test_interpolation4d.jl")
 include("test_interpolation5d.jl")
-
-#=
-plot!( nvec, errors["1D"];
-       title  = "L2 error scaling",
-       markershape = :circle, 
-       label  = "1D",
-       yscale = :log10)
-
-xlabel!("N")
-ylabel!("L2 error")
-savefig("errors.png")
-=#
diff --git a/test/test_interpolation1d.jl b/test/test_interpolation1d.jl
index 3f3090d..cbee3fb 100644
--- a/test/test_interpolation1d.jl
+++ b/test/test_interpolation1d.jl
@@ -46,7 +46,7 @@ end
         dx  = xe[2]-xe[1]
         fe  = cos.(xe.^2)
 
-        s   = interpolate( interp, fk, xe )
+        @time s   = interpolate( interp, fk, xe )
 
         l2_error = sqrt(trapz((s - fe).^2, dx))
         l1_error = maximum(abs.(s .- fe))
diff --git a/test/test_interpolation2d.jl b/test/test_interpolation2d.jl
index cc655fd..5533e73 100644
--- a/test/test_interpolation2d.jl
+++ b/test/test_interpolation2d.jl
@@ -23,7 +23,7 @@ fk = f(xk, yk)
 
 @testset "Hermite 2D" begin
 	
-    s = interpolate( hermite_x, hermite_y, fk, xe, ye )
+    @time s = interpolate( hermite_x, hermite_y, fk, xe, ye )
     
     max_error = maximum(abs.(s .- f(xe, ye)))
     l2_error = sqrt(trapz((s .- f(xe, ye)).^2, dx, dy))
diff --git a/test/test_interpolation3d.jl b/test/test_interpolation3d.jl
index 2d8d641..1addda4 100644
--- a/test/test_interpolation3d.jl
+++ b/test/test_interpolation3d.jl
@@ -2,9 +2,9 @@
 
     f(x, y, z) = cos(x^2 + y^2 + z^2)
 
-    xmin, xmax, nx = -1, 1, 16
-    ymin, ymax, ny = -1, 1, 16
-    zmin, zmax, nz = -1, 1, 16
+    xmin, xmax, nx = -1, 1, 32
+    ymin, ymax, ny = -1, 1, 32
+    zmin, zmax, nz = -1, 1, 32
 
     ϵ = 0.1
     γ = 3.0
@@ -17,7 +17,7 @@
     yk = hermite_y.nodes
     zk = hermite_z.nodes
 
-    nxe, nye, nze = 16, 32, 64
+    nxe, nye, nze = 32, 64, 128
 
     xe = collect(range(xmin, stop=xmax, length=nxe))
     ye = collect(range(ymin, stop=ymax, length=nye))
@@ -29,7 +29,7 @@
     fk = [f(x, y, z) for x in xk, y in yk, z in zk]
     fe = [f(x, y, z) for x in xe, y in ye, z in ze]
 
-    s  = interpolate(hermite_x, hermite_y, hermite_z, fk, xe, ye, ze)
+    @time s  = interpolate(hermite_x, hermite_y, hermite_z, fk, xe, ye, ze)
 
     max_error = maximum(abs.(s .- fe))
 
diff --git a/test/test_interpolation4d.jl b/test/test_interpolation4d.jl
index 574597f..f90f98b 100644
--- a/test/test_interpolation4d.jl
+++ b/test/test_interpolation4d.jl
@@ -2,10 +2,10 @@
 
     f(x, y, z, v) = cos(x^2 + y^2 + z^2 + v^2)
 
-    xmin, xmax, nx = -1, 1, 7
-    ymin, ymax, ny = -1, 1, 7
-    zmin, zmax, nz = -1, 1, 7
-    vmin, vmax, nv = -1, 1, 7
+    xmin, xmax, nx = -1, 1, 32
+    ymin, ymax, ny = -1, 1, 32
+    zmin, zmax, nz = -1, 1, 32
+    vmin, vmax, nv = -1, 1, 32
 
     ϵ = 0.1
     γ = 3.0
@@ -20,7 +20,7 @@
     zk = hermite_z.nodes
     vk = hermite_v.nodes
 
-    nxe, nye, nze, nve = 8, 16, 16, 8
+    nxe, nye, nze, nve = 32, 16, 16, 32
 
     xe = collect(range(xmin, stop=xmax, length=nxe))
     ye = collect(range(ymin, stop=ymax, length=nye))
@@ -34,7 +34,7 @@
     fk = [f(x, y, z, v) for x in xk, y in yk, z in zk, v in vk]
     fe = [f(x, y, z, v) for x in xe, y in ye, z in ze, v in ve]
 
-    s  = interpolate(hermite_x, hermite_y, hermite_z, hermite_v, fk, xe, ye, ze, ve)
+    @time s  = interpolate(hermite_x, hermite_y, hermite_z, hermite_v, fk, xe, ye, ze, ve)
 
     max_error = maximum(abs.(s .- fe))
 
diff --git a/test/test_interpolation5d.jl b/test/test_interpolation5d.jl
index b1156ab..35ad4f7 100644
--- a/test/test_interpolation5d.jl
+++ b/test/test_interpolation5d.jl
@@ -2,11 +2,11 @@
 
     f(x, y, z, v, w) = cos(x^2 + y^2 + z^2 + v^2 + w^2)
 
-    xmin, xmax, nx = -1, 1, 7
-    ymin, ymax, ny = -1, 1, 7
-    zmin, zmax, nz = -1, 1, 7
-    vmin, vmax, nv = -1, 1, 7
-    wmin, wmax, nw = -1, 1, 7
+    xmin, xmax, nx = -1, 1, 32
+    ymin, ymax, ny = -1, 1, 32
+    zmin, zmax, nz = -1, 1, 32
+    vmin, vmax, nv = -1, 1, 32
+    wmin, wmax, nw = -1, 1, 32
 
     ϵ = 0.1
     γ = 3.0
@@ -23,7 +23,7 @@
     vk = hermite_v.nodes
     wk = hermite_w.nodes
 
-    nxe, nye, nze, nve, nwe = 8, 4, 8, 8, 4
+    nxe, nye, nze, nve, nwe = 32, 32, 32, 32, 32
 
     xe = collect(range(xmin, stop=xmax, length=nxe))
     ye = collect(range(ymin, stop=ymax, length=nye))
@@ -39,7 +39,7 @@
     fk = [f(x, y, z, v, w) for x in xk, y in yk, z in zk, v in vk, w in wk]
     fe = [f(x, y, z, v, w) for x in xe, y in ye, z in ze, v in ve, w in we]
 
-    s  = interpolate(hermite_x, hermite_y, hermite_z, hermite_v, hermite_w, fk, xe, ye, ze, ve, we)
+    @time s  = interpolate(hermite_x, hermite_y, hermite_z, hermite_v, hermite_w, fk, xe, ye, ze, ve, we)
 
     max_error = maximum(abs.(s .- fe))
 
-- 
GitLab