diff --git a/bfps/DNS.py b/bfps/DNS.py
index 8cbe8d9cc3c165054212b363a962d10471013a44..7a97e2f1132ea9862e2a8476bcd03acacd72296c 100644
--- a/bfps/DNS.py
+++ b/bfps/DNS.py
@@ -151,9 +151,11 @@ class DNS(_code):
         self.parameters['nu'] = float(0.1)
         self.parameters['fmode'] = int(1)
         self.parameters['famplitude'] = float(0.5)
+        self.parameters['energy'] = float(0.5)
+        self.parameters['injection_rate'] = float(0.4)
         self.parameters['fk0'] = float(2.0)
         self.parameters['fk1'] = float(4.0)
-        self.parameters['forcing_type'] = 'linear'
+        self.parameters['forcing_type'] = 'fixed_energy_injection_rate'
         self.parameters['histogram_bins'] = int(256)
         self.parameters['max_velocity_estimate'] = float(1)
         self.parameters['max_vorticity_estimate'] = float(1)
diff --git a/bfps/cpp/full_code/NSVE.cpp b/bfps/cpp/full_code/NSVE.cpp
index 1e24c7af531e7184f75b1f14257d42b822db7a9c..ba8a3ed65eaf17738e363f3e8c82246f1dfa5dd7 100644
--- a/bfps/cpp/full_code/NSVE.cpp
+++ b/bfps/cpp/full_code/NSVE.cpp
@@ -47,6 +47,8 @@ int NSVE<rnumber>::initialize(void)
     this->fs->nu = nu;
     this->fs->fmode = fmode;
     this->fs->famplitude = famplitude;
+    this->fs->energy = energy;
+    this->fs->injection_rate = injection_rate;
     this->fs->fk0 = fk0;
     this->fs->fk1 = fk1;
     strncpy(this->fs->forcing_type, forcing_type, 128);
diff --git a/bfps/cpp/full_code/NSVE.hpp b/bfps/cpp/full_code/NSVE.hpp
index d444b71ceb48ea19dc292a57cc91ac81157e15ed..e3f6b2765875084ed152e914cd285e331502d673 100644
--- a/bfps/cpp/full_code/NSVE.hpp
+++ b/bfps/cpp/full_code/NSVE.hpp
@@ -44,6 +44,8 @@ class NSVE: public direct_numerical_simulation
         double famplitude;
         double fk0;
         double fk1;
+        double energy;
+        double injection_rate;
         int fmode;
         char forcing_type[512];
         int histogram_bins;
diff --git a/bfps/cpp/vorticity_equation.cpp b/bfps/cpp/vorticity_equation.cpp
index 6266050569d2bd59c04f3641f1548432cbd101b0..86e4a97ec3c117482af2546a7d3d0dc28817b8ff 100644
--- a/bfps/cpp/vorticity_equation.cpp
+++ b/bfps/cpp/vorticity_equation.cpp
@@ -228,8 +228,7 @@ template <class rnumber,
           field_backend be>
 void vorticity_equation<rnumber, be>::add_forcing(
         field<rnumber, be, THREE> *dst,
-        field<rnumber, be, THREE> *vort_field,
-        rnumber factor)
+        field<rnumber, be, THREE> *vort_field)
 {
     TIMEZONE("vorticity_equation::add_forcing");
     if (strcmp(this->forcing_type, "none") == 0)
@@ -239,13 +238,13 @@ void vorticity_equation<rnumber, be>::add_forcing(
         ptrdiff_t cindex;
         if (this->cvorticity->clayout->myrank == this->cvorticity->clayout->rank[0][this->fmode])
         {
-            cindex = ((this->fmode - this->cvorticity->clayout->starts[0]) * this->cvorticity->clayout->sizes[1])*this->cvorticity->clayout->sizes[2];
-            dst->cval(cindex,2, 0) -= this->famplitude*factor/2;
+            cindex = dst->get_cindex(0, (this->fmode - this->cvorticity->clayout->starts[0]), 0);
+            dst->cval(cindex,2, 0) -= this->famplitude/2;
         }
         if (this->cvorticity->clayout->myrank == this->cvorticity->clayout->rank[0][this->cvorticity->clayout->sizes[0] - this->fmode])
         {
-            cindex = ((this->cvorticity->clayout->sizes[0] - this->fmode - this->cvorticity->clayout->starts[0]) * this->cvorticity->clayout->sizes[1])*this->cvorticity->clayout->sizes[2];
-            dst->cval(cindex, 2, 0) -= this->famplitude*factor/2;
+            cindex = dst->get_cindex(0, (this->cvorticity->clayout->sizes[0] - this->fmode - this->cvorticity->clayout->starts[0]), 0);
+            dst->cval(cindex, 2, 0) -= this->famplitude/2;
         }
         return;
     }
@@ -260,10 +259,99 @@ void vorticity_equation<rnumber, be>::add_forcing(
                                 this->kk->ky[yindex]*this->kk->ky[yindex] +
                                 this->kk->kz[zindex]*this->kk->kz[zindex]);
             if ((this->fk0 <= knorm) &&
+                (this->fk1 >= knorm))
+                for (int c=0; c<3; c++)
+                    for (int i=0; i<2; i++)
+                        dst->cval(cindex,c,i) += this->famplitude*vort_field->cval(cindex,c,i);
+        }
+        );
+        return;
+    }
+    if (strcmp(this->forcing_type, "fixed_energy_injection_rate") == 0)
+    {
+        // first, compute energy in shell
+        double energy_in_shell = 0;
+        this->kk->CLOOP_K2(
+                    [&](ptrdiff_t cindex,
+                        ptrdiff_t xindex,
+                        ptrdiff_t yindex,
+                        ptrdiff_t zindex,
+                        double k2){
+            double knorm = sqrt(k2);
+            if ((k2 > 0) &&
+                (this->fk0 <= knorm) &&
+                (this->fk1 >= knorm))
+                    energy_in_shell += (
+                            vort_field->cval(cindex, 0, 0)*vort_field->cval(cindex, 0, 0) + vort_field->cval(cindex, 0, 1)*vort_field->cval(cindex, 0, 1) +
+                            vort_field->cval(cindex, 1, 0)*vort_field->cval(cindex, 1, 0) + vort_field->cval(cindex, 1, 1)*vort_field->cval(cindex, 1, 1) +
+                            vort_field->cval(cindex, 2, 0)*vort_field->cval(cindex, 2, 0) + vort_field->cval(cindex, 2, 1)*vort_field->cval(cindex, 2, 1)
+                            ) / k2;
+        }
+        );
+        // divide by 2, because we want energy
+        energy_in_shell /= 2;
+        // now, add forcing term
+        double temp_famplitude = this->injection_rate / energy_in_shell;
+        this->kk->CLOOP_K2(
+                    [&](ptrdiff_t cindex,
+                        ptrdiff_t xindex,
+                        ptrdiff_t yindex,
+                        ptrdiff_t zindex,
+                        double k2){
+            double knorm = sqrt(k2);
+            if ((this->fk0 <= knorm) &&
+                (this->fk1 >= knorm))
+                for (int c=0; c<3; c++)
+                    for (int i=0; i<2; i++)
+                        dst->cval(cindex,c,i) += temp_famplitude*vort_field->cval(cindex,c,i);
+        }
+        );
+        return;
+    }
+    if (strcmp(this->forcing_type, "fixed_energy") == 0)
+    {
+        // first, compute energy in shell
+        double energy_in_shell = 0;
+        double total_energy = 0;
+        this->kk->CLOOP_K2(
+                    [&](ptrdiff_t cindex,
+                        ptrdiff_t xindex,
+                        ptrdiff_t yindex,
+                        ptrdiff_t zindex,
+                        double k2){
+            if (k2 > 0)
+            {
+                double local_energy = (
+                            vort_field->cval(cindex, 0, 0)*vort_field->cval(cindex, 0, 0) + vort_field->cval(cindex, 0, 1)*vort_field->cval(cindex, 0, 1) +
+                            vort_field->cval(cindex, 1, 0)*vort_field->cval(cindex, 1, 0) + vort_field->cval(cindex, 1, 1)*vort_field->cval(cindex, 1, 1) +
+                            vort_field->cval(cindex, 2, 0)*vort_field->cval(cindex, 2, 0) + vort_field->cval(cindex, 2, 1)*vort_field->cval(cindex, 2, 1)
+                            ) / k2;
+                total_energy += local_energy;
+                double knorm = sqrt(k2);
+                if ((this->fk0 <= knorm) &&
                     (this->fk1 >= knorm))
+                    energy_in_shell += local_energy;
+            }
+        }
+        );
+        // divide by 2, because we want energy
+        total_energy /= 2;
+        energy_in_shell /= 2;
+        // now, add forcing term
+        // see Michael's thesis, page 38
+        double temp_famplitude = sqrt((this->energy - total_energy + energy_in_shell) / energy_in_shell);
+        this->kk->CLOOP_K2(
+                    [&](ptrdiff_t cindex,
+                        ptrdiff_t xindex,
+                        ptrdiff_t yindex,
+                        ptrdiff_t zindex,
+                        double k2){
+            double knorm = sqrt(k2);
+            if ((this->fk0 <= knorm) &&
+                (this->fk1 >= knorm))
                 for (int c=0; c<3; c++)
                     for (int i=0; i<2; i++)
-                        dst->cval(cindex,c,i) += this->famplitude*vort_field->cval(cindex,c,i)*factor;
+                        dst->cval(cindex,c,i) += temp_famplitude*vort_field->cval(cindex,c,i);
         }
         );
         return;
@@ -320,7 +408,7 @@ void vorticity_equation<rnumber, be>::omega_nonlin(
             this->u->cval(cindex, cc, i) = tmp[cc][i];
     }
     );
-    this->add_forcing(this->u, this->v[src], 1.0);
+    this->add_forcing(this->u, this->v[src]);
     this->kk->template force_divfree<rnumber>(this->u->get_cdata());
 }
 
diff --git a/bfps/cpp/vorticity_equation.hpp b/bfps/cpp/vorticity_equation.hpp
index e8bd1d843f730d39439bc99703956dc623ca4e42..21a5f0391e6bc5e7666ef3e823a65bf2bbf1f1be 100644
--- a/bfps/cpp/vorticity_equation.hpp
+++ b/bfps/cpp/vorticity_equation.hpp
@@ -67,9 +67,11 @@ class vorticity_equation
 
         /* physical parameters */
         double nu;
-        int fmode;         // for Kolmogorov flow
-        double famplitude; // both for Kflow and band forcing
-        double fk0, fk1;   // for band forcing
+        int fmode;             // for Kolmogorov flow
+        double famplitude;     // both for Kflow and band forcing
+        double fk0, fk1;       // for band forcing
+        double injection_rate; // for fixed energy injection rate
+        double energy;         // for fixed energy
         char forcing_type[128];
 
         /* constructor, destructor */
@@ -89,8 +91,7 @@ class vorticity_equation
         void step(double dt);
         void impose_zero_modes(void);
         void add_forcing(field<rnumber, be, THREE> *dst,
-                         field<rnumber, be, THREE> *src_vorticity,
-                         rnumber factor);
+                         field<rnumber, be, THREE> *src_vorticity);
         void compute_vorticity(void);
         void compute_velocity(field<rnumber, be, THREE> *vorticity);
 
diff --git a/bfps/test/test_bfps_NSVEparticles.py b/bfps/test/test_bfps_NSVEparticles.py
index ab77e2103ccda7685cebe759f8e11cfe2a5b5ec9..33212e7670728cfeb1f180d2dc51d37653724e86 100644
--- a/bfps/test/test_bfps_NSVEparticles.py
+++ b/bfps/test/test_bfps_NSVEparticles.py
@@ -18,6 +18,7 @@ def main():
             ['NSVEparticles',
              '-n', '32',
              '--src-simname', 'B32p1e4',
+             '--forcing_type', 'linear',
              '--src-wd', bfps.lib_dir + '/test',
              '--src-iteration', '0',
              '--simname', 'dns_nsveparticles',