// NOTE(review): this listing looks like a Doxygen source-page scrape. The original
// line numbers are embedded in the text, and the gaps in that numbering show that
// lines (class head, braces, comments, parts of signatures) are missing. It will
// not compile as shown; restore it from the upstream header before building.
49 #ifndef Intrepid2_DirectSumBasis_h 50 #define Intrepid2_DirectSumBasis_h 52 #include <Kokkos_View.hpp> 53 #include <Kokkos_DynRankView.hpp> 67 template<
typename BasisBaseClass>
// Basis_DirectSumBasis: a basis that is the direct sum of two bases of type
// BasisBaseClass (its constructor takes two BasisPtr arguments).
// The aliases below re-export BasisBaseClass's types, so the direct sum uses the
// same device, value and view types as the two bases it combines.
71 using BasisBase = BasisBaseClass;
// Teuchos reference-counted pointer used to hold each component basis.
72 using BasisPtr = Teuchos::RCP<BasisBase>;
// Execution/memory types, forwarded from the base class.
74 using DeviceType =
typename BasisBase::DeviceType;
75 using ExecutionSpace =
typename BasisBase::ExecutionSpace;
// Scalar types for basis values and for evaluation points.
76 using OutputValueType =
typename BasisBase::OutputValueType;
77 using PointValueType =
typename BasisBase::PointValueType;
// Host-side ordinal arrays (used for the DOF tag data and the polynomial-degree table).
79 using OrdinalTypeArray1DHost =
typename BasisBase::OrdinalTypeArray1DHost;
80 using OrdinalTypeArray2DHost =
typename BasisBase::OrdinalTypeArray2DHost;
// View types accepted by getValues(), getDofCoords() and getDofCoeffs().
81 using OutputViewType =
typename BasisBase::OutputViewType;
82 using PointViewType =
typename BasisBase::PointViewType;
83 using ScalarViewType =
typename BasisBase::ScalarViewType;
// Constructor Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2). The signature
// line and braces are missing from this listing.
// It stores both bases, checks that they are compatible, and derives from them:
// cardinality, degree, name, cell topology, basis type, coordinate system, a
// per-field polynomial-degree table (hierarchical bases only), and the DOF tag data.
96 basis1_(basis1),basis2_(basis2)
// Both bases must agree on basis type, cell topology (compared by key) and
// coordinate system; otherwise std::invalid_argument is thrown.
// NOTE(review): basis1/basis2 are dereferenced without a null check; confirm
// that callers never pass an empty RCP.
98 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument,
"basis1 and basis2 must agree in basis type");
99 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
100 std::invalid_argument,
"basis1 and basis2 must agree in cell topology");
101 INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
102 std::invalid_argument,
"basis1 and basis2 must agree in coordinate system");
// Field ordinals [0, card1) belong to basis1 and [card1, card1+card2) to basis2.
// The degree is the larger of the two component degrees.
104 this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
105 this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
// Name is "<basis1 name> + <basis2 name>".
108 std::ostringstream basisName;
109 basisName << basis1->getName() <<
" + " << basis2->getName();
110 name_ = basisName.str();
// Topology, basis type and coordinates are taken from basis1; the checks above
// guarantee that basis2 agrees.
113 this->basisCellTopology_ = basis1->getBaseCellTopology();
114 this->basisType_ = basis1->getBasisType();
115 this->basisCoordinates_ = basis1->getCoordinateSystem();
// Hierarchical bases only: build a (cardinality x degreeLength) table of per-field
// polynomial degrees by concatenating basis1's entries and then basis2's.
// Both bases must report the same polynomial-degree length.
117 if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
119 int degreeLength = basis1_->getPolynomialDegreeLength();
120 INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument,
"Basis1 and Basis2 must agree on polynomial degree length");
122 this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost(
"DirectSumBasis degree lookup",this->basisCardinality_,degreeLength);
// basis1 fields map to direct-sum ordinals unchanged.
124 for (
int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
126 int fieldOrdinal = fieldOrdinal1;
127 auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
128 for (
int d=0; d<degreeLength; d++)
130 this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
// basis2 fields are shifted by basis1's cardinality.
133 for (
int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
135 int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
137 auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
138 for (
int d=0; d<degreeLength; d++)
140 this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
// DOF tag data: each field gets tagSize (=4) entries in tagView, in this order:
// subcell dimension, subcell ordinal, local DOF id on that subcell, and DOF count
// on that subcell.
// posScDim/posScOrd/posDfOrd are not used in the visible lines; presumably they
// are consumed by the partly-missing setOrdinalTagData call below.
147 const auto & cardinality = this->basisCardinality_;
150 const ordinal_type tagSize = 4;
151 const ordinal_type posScDim = 0;
152 const ordinal_type posScOrd = 1;
153 const ordinal_type posDfOrd = 2;
155 OrdinalTypeArray1DHost tagView(
"tag view", cardinality*tagSize);
157 shards::CellTopology cellTopo = this->basisCellTopology_;
159 unsigned spaceDim = cellTopo.getDimension();
// Direct-sum field ordinal of basis2's field 0.
161 ordinal_type basis2Offset = basis1_->getCardinality();
// Visit every subcell of every dimension 0..spaceDim (inclusive).
163 for (
unsigned d=0; d<=spaceDim; d++)
165 unsigned subcellCount = cellTopo.getSubcellCount(d);
166 for (
unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
168 ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
169 ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
171 ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
// On each subcell, basis1's DOFs take local ids [0, count1); basis2's DOFs follow
// at [count1, count1+count2).
172 for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
174 ordinal_type fieldOrdinal;
175 if (localDofID < subcellDofCount1)
// basis1 DOF: its field ordinal is used as-is.
178 fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
// Otherwise (the `else` is among the missing lines) this is a basis2 DOF: convert
// to basis2's local id and shift its field ordinal by basis2Offset.
183 fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
185 tagView(fieldOrdinal*tagSize+0) = d;
186 tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
187 tagView(fieldOrdinal*tagSize+2) = localDofID;
188 tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
// Hand the tag data to the base class. Several arguments of this call are missing
// from the listing.
194 this->setOrdinalTagData(this->tagToOrdinal_,
197 this->basisCardinality_,
// Fragment, presumably the body of allocateBasisValues(); its head, braces, the
// `else` between the two cases, and the return statements are missing from this
// listing.
// It concatenates basisValues1 and basisValues2 into one family list, with
// basis1's families first.
// Scalar case: basis1 stores its values as tensor-data families, so basis2 must too.
215 const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
216 if (numScalarFamilies1 > 0)
219 const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
// NOTE(review): this re-queries numTensorDataFamilies() instead of reusing
// numScalarFamilies2; the result is the same.
220 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument,
"When basis1 has scalar value, basis2 must also");
221 std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
// basis1 families go in slots [0, n1).
222 for (
int i=0; i<numScalarFamilies1; i++)
224 scalarFamilies[i] = basisValues1.
tensorData(i);
// basis2 families go in slots [n1, n1+n2).
226 for (
int i=0; i<numScalarFamilies2; i++)
228 scalarFamilies[i+numScalarFamilies1] = basisValues2.
tensorData(i);
// Vector case: basis1 must have valid vectorData(), and basis2 must not have
// tensor-data families.
// NOTE(review): basis2's vectorData().isValid() is not checked here; confirm
// whether it should be.
235 INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.
vectorData().isValid(), std::invalid_argument,
"When basis1 does not have tensorData() defined, it must have a valid vectorData()");
236 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument,
"When basis1 has vector value, basis2 must also");
238 const auto & vectorData1 = basisValues1.
vectorData();
239 const auto & vectorData2 = basisValues2.
vectorData();
// Both bases must have the same number of vector components.
241 const int numFamilies1 = vectorData1.numFamilies();
242 const int numComponents = vectorData1.numComponents();
243 INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument,
"basis1 and basis2 must agree on the number of components in each vector");
244 const int numFamilies2 = vectorData2.numFamilies();
// numFamilies is unused in the visible lines. vectorComponents is declared in
// missing lines 247-248, presumably sized by numFamilies x numComponents.
246 const int numFamilies = numFamilies1 + numFamilies2;
// Copy basis1's (family, component) entries into rows [0, numFamilies1).
249 for (
int i=0; i<numFamilies1; i++)
251 for (
int j=0; j<numComponents; j++)
253 vectorComponents[i][j] = vectorData1.getComponent(i,j);
// Copy basis2's entries into rows [numFamilies1, numFamilies1+numFamilies2).
256 for (
int i=0; i<numFamilies2; i++)
258 for (
int j=0; j<numComponents; j++)
260 vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
// getDofCoords(ScalarViewType dofCoords); the signature and braces are missing
// from this listing.
// Rows [0, card1) of dofCoords are filled by basis1 and rows [card1, card1+card2)
// by basis2. Each component basis writes directly into its own row subview.
277 const int basisCardinality1 = basis1_->getCardinality();
278 const int basisCardinality2 = basis2_->getCardinality();
279 const int basisCardinality = basisCardinality1 + basisCardinality2;
281 auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
282 auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
284 basis1_->getDofCoords(dofCoords1);
285 basis2_->getDofCoords(dofCoords2);
// getDofCoeffs(ScalarViewType dofCoeffs); the signature and braces are missing
// from this listing.
// Rows [0, card1) of dofCoeffs are filled by basis1 and rows [card1, card1+card2)
// by basis2, each through its own row subview.
300 const int basisCardinality1 = basis1_->getCardinality();
301 const int basisCardinality2 = basis2_->getCardinality();
302 const int basisCardinality = basisCardinality1 + basisCardinality2;
304 auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
305 auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
307 basis1_->getDofCoeffs(dofCoeffs1);
308 basis2_->getDofCoeffs(dofCoeffs2);
// getName(): returns name_, which the constructor builds as
// "<basis1 name> + <basis2 name>". The pointer stays valid while name_ is unmodified.
319 return name_.c_str();
// Re-expose the base-class getValues overloads so the overrides declared in this
// class do not hide them.
325 using BasisBase::getValues;
// getValues(BasisValues outputValues, TensorPoints inputPoints, EOperator). The
// leading part of the signature is missing from this listing.
342 const EOperator operatorType = OPERATOR_VALUE )
const override 344 const int fieldStartOrdinal1 = 0;
// basis1 owns fields [0, numFields1) and basis2 owns [numFields1, numFields1+numFields2).
345 const int numFields1 = basis1_->getCardinality();
346 const int fieldStartOrdinal2 = numFields1;
347 const int numFields2 = basis2_->getCardinality();
// basisValues1/basisValues2 are created in missing lines 348-351, presumably as the
// per-field-range views of outputValues (cf. BasisValues::basisValuesForFields).
// Each component basis then evaluates into its own view.
352 basis1_->getValues(basisValues1, inputPoints, operatorType);
353 basis2_->getValues(basisValues2, inputPoints, operatorType);
// Evaluates the direct-sum basis at inputPoints. basis1 fills field range
// [0, card1) of outputValues, and basis2 fills [card1, card1+card2). The field
// range is applied to the first dimension of outputValues.
// The code branches on rank (2..7) so it can pass one slice argument per
// dimension to Kokkos::subview. Any other rank throws std::invalid_argument.
// Braces and the final `else` are missing from this listing.
374 virtual void getValues( OutputViewType outputValues,
const PointViewType inputPoints,
375 const EOperator operatorType = OPERATOR_VALUE )
const override 377 int cardinality1 = basis1_->getCardinality();
378 int cardinality2 = basis2_->getCardinality();
// Half-open field ranges for each component basis.
380 auto range1 = std::make_pair(0,cardinality1);
381 auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
382 if (outputValues.rank() == 2)
384 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
385 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
387 basis1_->getValues(outputValues1, inputPoints, operatorType);
388 basis2_->getValues(outputValues2, inputPoints, operatorType);
390 else if (outputValues.rank() == 3)
392 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
393 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
395 basis1_->getValues(outputValues1, inputPoints, operatorType);
396 basis2_->getValues(outputValues2, inputPoints, operatorType);
398 else if (outputValues.rank() == 4)
400 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
401 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
403 basis1_->getValues(outputValues1, inputPoints, operatorType);
404 basis2_->getValues(outputValues2, inputPoints, operatorType);
406 else if (outputValues.rank() == 5)
408 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
409 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
411 basis1_->getValues(outputValues1, inputPoints, operatorType);
412 basis2_->getValues(outputValues2, inputPoints, operatorType);
414 else if (outputValues.rank() == 6)
416 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
417 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
419 basis1_->getValues(outputValues1, inputPoints, operatorType);
420 basis2_->getValues(outputValues2, inputPoints, operatorType);
422 else if (outputValues.rank() == 7)
424 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
425 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
427 basis1_->getValues(outputValues1, inputPoints, operatorType);
428 basis2_->getValues(outputValues2, inputPoints, operatorType);
// Any rank outside 2..7 (presumably reached via a missing `else`) is rejected.
432 INTREPID2_TEST_FOR_EXCEPTION(
true, std::invalid_argument,
"Unsupported outputValues rank");
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow pre...
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell...
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoi...
View-like interface to tensor points; point components are stored separately; the appropriate coordin...
BasisValues< Scalar, ExecSpaceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is v...
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
EOperator
Enumeration of primitive operators available in Intrepid. Primitive operators act on reconstructed fu...
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell...
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
View-like interface to tensor data; tensor components are stored separately and multiplied together a...
const VectorDataType & vectorData() const
VectorData accessor.
Reference-space field values for a basis, designed to support typical vector-valued bases...
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.