/* Copyright (c) 2008-2025 the MRtrix3 contributors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Covered Software is provided under this License on an "as is"
 * basis, without warranty of any kind, either expressed, implied, or
 * statutory, including, without limitation, warranties that the
 * Covered Software is free of defects, merchantable, fit for a
 * particular purpose or non-infringing.
 * See the Mozilla Public License v. 2.0 for more details.
 *
 * For more details, see http://www.mrtrix.org/.
 */

#include "dwi/tractography/GT/particlegrid.h"

namespace MR {
  namespace DWI {
    namespace Tractography {
      namespace GT {


        // Set up an empty particle grid for the image described by H.
        //
        // The grid is isotropic: its cell size is the smallest voxel spacing of
        // the first three image axes, but never smaller than 2 * Particle::L.
        // The cell counts per axis are stored in dims, the cell storage is sized
        // accordingly, and the scanner-to-grid transform T_s2g is derived from
        // the header transform scaled by the cell size.
        ParticleGrid::ParticleGrid(const Header& H)
        {
          DEBUG("Initialise particle grid.");

          // Isotropic cell size, bounded below by twice the particle length constant
          const default_type smallest_voxel = std::min(H.spacing(0), std::min(H.spacing(1), H.spacing(2)));
          grid_spacing = std::max(2.0 * Particle::L, smallest_voxel);

          // Number of cells along each of the three spatial axes
          for (size_t axis = 0; axis != 3; ++axis)
            dims[axis] = Math::ceil<size_t>( (H.size(axis)-1) * H.spacing(axis) / grid_spacing ) + 1;
          grid.resize(dims[0]*dims[1]*dims[2]);

          // Grid-to-scanner is the header transform applied after scaling by the
          // cell size; the stored scanner-to-grid transform is its inverse.
          Eigen::DiagonalMatrix<default_type, 3> cell_scaling (grid_spacing, grid_spacing, grid_spacing);
          transform_type grid_to_scanner = H.transform() * cell_scaling;
          T_s2g = grid_to_scanner.inverse();
        }


        void ParticleGrid::add(const Point_t &pos, const Point_t &dir)
        {
          Particle* p = pool.create(pos, dir);
          size_t gidx = pos2idx(pos);
          std::lock_guard<std::mutex> lock (mutex);
          grid[gidx].push_back(p);
        }

        // Move particle p to a new position and direction, transferring it from
        // the cell of its current position to the cell of pos.
        void ParticleGrid::shift(Particle *p, const Point_t& pos, const Point_t& dir)
        {
          // Both cell indices are computed before taking the grid mutex.
          const size_t from_cell = pos2idx(p->getPosition());
          const size_t to_cell = pos2idx(pos);

          std::lock_guard<std::mutex> guard (mutex);
          ParticleContainer& origin = grid[from_cell];
          origin.remove(p);
          p->setPosition(pos);
          p->setDirection(dir);
          // Re-inserted even when the cell is unchanged, so p ends up at the back.
          grid[to_cell].push_back(p);
        }

        // Unregister particle p from the cell of its current position, then
        // hand it back to the pool.
        void ParticleGrid::remove(Particle* p)
        {
          const size_t cell = pos2idx(p->getPosition());

          std::unique_lock<std::mutex> guard (mutex);
          grid[cell].remove (p);
          // Release the grid mutex before the pool destroys the particle.
          guard.unlock();

          pool.destroy(p);
        }

        // Remove all particles from the grid and from the pool.
        //
        // The number of grid cells is preserved: dims is not modified here, and
        // at() / add() / shift() / remove() index into grid based on dims, so
        // shrinking grid to zero cells would make any later access run out of
        // bounds. Every cell is therefore replaced by an empty one.
        //
        // NOTE(review): unlike add/shift/remove this does not take the grid
        // mutex; presumably it is only called while no other thread uses the
        // grid - confirm against callers.
        void ParticleGrid::clear()
        {
          const size_t n_cells = grid.size();
          grid.clear();          // drop all per-cell particle lists
          grid.resize(n_cells);  // restore the same number of (now empty) cells
          pool.clear();
        }

        // Look up the particle container of grid cell (x, y, z).
        // Returns nullptr when any coordinate is negative or not below the
        // corresponding entry of dims.
        const ParticleGrid::ParticleContainer* ParticleGrid::at(const ssize_t x, const ssize_t y, const ssize_t z) const
        {
          const bool x_inside = (x >= 0) && (static_cast<size_t>(x) < dims[0]);
          const bool y_inside = (y >= 0) && (static_cast<size_t>(y) < dims[1]);
          const bool z_inside = (z >= 0) && (static_cast<size_t>(z) < dims[2]);

          if (x_inside && y_inside && z_inside)
            return &grid[xyz2idx(x, y, z)];

          return nullptr;  // out of bounds
        }

        // Write every chain of linked particles to writer as one track.
        //
        // The grid mutex is held for the whole export. Each particle that has
        // not been visited yet seeds a track: the chain is followed from the
        // seed in the successor direction, the collected points are reversed,
        // and the chain is then followed from the seed in the predecessor
        // direction. Each walk appends the end point of the last particle it
        // reached. The visited flag of every particle in the grid is reset to
        // false before returning.
        void ParticleGrid::exportTracks(Tractography::Writer<float> &writer)
        {
          std::lock_guard<std::mutex> lock (mutex);
          vector<Point_t> track;

          // Follow the chain from par, leaving it in direction alpha (+1: via
          // successor, -1: via predecessor). Appends the position of every newly
          // reached particle and, finally, the end point of the last particle in
          // the current direction of travel.
          auto walk = [&track] (Particle* par, int alpha)
          {
            while ((alpha == +1) ? par->hasSuccessor() : par->hasPredecessor())
            {
              Particle* nextpar = (alpha == +1) ? par->getSuccessor() : par->getPredecessor();
              // A link back to an already visited particle means the chain is
              // closed on itself; stop here, otherwise the walk never terminates.
              if (nextpar->isVisited())
                break;
              // Direction in which to leave nextpar: onwards through its
              // successor if we entered through its predecessor link, else back
              // through its predecessor.
              alpha = (nextpar->getPredecessor() == par) ? +1 : -1;
              track.push_back(nextpar->getPosition());
              nextpar->setVisited(true);
              par = nextpar;
            }
            track.push_back(par->getEndPoint(alpha));
          };

          // Loop through all unvisited particles
          for (ParticleContainer& gridvox : grid)
          {
            for (Particle* par0 : gridvox)
            {
              if (par0->isVisited())
                continue;
              par0->setVisited(true);
              // forward
              track.push_back(par0->getPosition());
              walk(par0, +1);
              // backward: flip the points so the seed is last, then continue
              // from the seed in the opposite direction
              std::reverse(track.begin(), track.end());
              walk(par0, -1);
              if (track.size() > 1)
                writer(track);
              track.clear();
            }
          }

          // Reset the visited flag of all particles
          for (ParticleContainer& gridvox : grid) {
            for (Particle* par : gridvox) {
              par->setVisited(false);
            }
          }
        }


      }
    }
  }
}
