162 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			C++
		
	
	
			
		
		
	
	
			162 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			C++
		
	
	
| /*
 | |
|  * Tensor3Expression.h
 | |
|  * @brief Tensor expression templates based on http://www.gps.caltech.edu/~walter/FTensor/FTensor.pdf
 | |
|  * Created on: Feb 10, 2010
 | |
|  * @author: Frank Dellaert
 | |
|  */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include <iostream>
 | |
| #include "tensors.h"
 | |
| 
 | |
| namespace tensors {
 | |
| 
 | |
| 	/** templated class to interface to an object A as a rank 3 tensor */
 | |
| 	template<class A, class I, class J, class K> class Tensor3Expression {
 | |
| 		A iter;
 | |
| 
 | |
| 		typedef Tensor3Expression<A, I, J, K> This;
 | |
| 
 | |
| 		/** Helper class for instantiating one index */
 | |
| 		class FixK_ {
 | |
| 			const int k;
 | |
| 			const A iter;
 | |
| 		public:
 | |
| 			FixK_(int k_, const A &a) :
 | |
| 				k(k_), iter(a) {
 | |
| 			}
 | |
| 			double operator()(int i, int j) const {
 | |
| 				return iter(i, j, k);
 | |
| 			}
 | |
| 		};
 | |
| 
 | |
| 		/** Helper class for contracting rank3 and rank1 tensor */
 | |
| 		template<class B> class TimesRank1_ {
 | |
| 			typedef Tensor1Expression<B, I> Rank1;
 | |
| 			const This T;
 | |
| 			const Rank1 t;
 | |
| 		public:
 | |
| 			TimesRank1_(const This &a, const Rank1 &b) :
 | |
| 				T(a), t(b) {
 | |
| 			}
 | |
| 			double operator()(int j, int k) const {
 | |
| 				double sum = 0.0;
 | |
| 				for (int i = 0; i < I::dim; i++)
 | |
| 					sum += T(i, j, k) * t(i);
 | |
| 				return sum;
 | |
| 			}
 | |
| 		};
 | |
| 
 | |
| 	public:
 | |
| 
 | |
| 		/** constructor */
 | |
| 		Tensor3Expression(const A &a) :
 | |
| 			iter(a) {
 | |
| 		}
 | |
| 
 | |
| 		/** Print */
 | |
| 		void print(const std::string& s = "Tensor3:") const {
 | |
| 			std::cout << s << "{";
 | |
| 			(*this)(0).print("");
 | |
| 			for (int k = 1; k < K::dim; k++) {
 | |
| 				std::cout << ",";
 | |
| 				(*this)(k).print("");
 | |
| 			}
 | |
| 			std::cout << "}" << std::endl;
 | |
| 		}
 | |
| 
 | |
| 		template<class B>
 | |
| 		bool equals(const Tensor3Expression<B, I, J, K> & q, double tol) const {
 | |
| 			for (int k = 0; k < K::dim; k++)
 | |
| 				if (!(*this)(k).equals(q(k), tol)) return false;
 | |
| 			return true;
 | |
| 		}
 | |
| 
 | |
| 		/** element access */
 | |
| 		double operator()(int i, int j, int k) const {
 | |
| 			return iter(i, j, k);
 | |
| 		}
 | |
| 
 | |
| 		/** Fix a single index */
 | |
| 		Tensor2Expression<FixK_, I, J> operator()(int k) const {
 | |
| 			return FixK_(k, iter);
 | |
| 		}
 | |
| 
 | |
| 		/** Contracting with rank1 tensor */
 | |
| 		template<class B>
 | |
| 		inline Tensor2Expression<TimesRank1_<B> , J, K> operator*(
 | |
| 				const Tensor1Expression<B, I> &b) const {
 | |
| 			return TimesRank1_<B> (*this, b);
 | |
| 		}
 | |
| 
 | |
| 	}; // Tensor3Expression
 | |
| 
 | |
| 	/** Print */
 | |
| 	template<class A, class I, class J, class K>
 | |
| 	void print(const Tensor3Expression<A, I, J, K>& T, const std::string& s =
 | |
| 			"Tensor3:") {
 | |
| 		T.print(s);
 | |
| 	}
 | |
| 
 | |
| 	/** Helper class for outer product of rank2 and rank1 tensor */
 | |
| 	template<class A, class I, class J, class B, class K>
 | |
| 	class Rank2Rank1_ {
 | |
| 		typedef Tensor2Expression<A, I, J> Rank2;
 | |
| 		typedef Tensor1Expression<B, K> Rank1;
 | |
| 		const Rank2 iterA;
 | |
| 		const Rank1 iterB;
 | |
| 	public:
 | |
| 		Rank2Rank1_(const Rank2 &a, const Rank1 &b) :
 | |
| 			iterA(a), iterB(b) {
 | |
| 		}
 | |
| 		double operator()(int i, int j, int k) const {
 | |
| 			return iterA(i, j) * iterB(k);
 | |
| 		}
 | |
| 	};
 | |
| 
 | |
| 	/** outer product of rank2 and rank1 tensor */
 | |
| 	template<class A, class I, class J, class B, class K>
 | |
| 	inline Tensor3Expression<Rank2Rank1_<A, I, J, B, K> , I, J, K> operator*(
 | |
| 			const Tensor2Expression<A, I, J>& a, const Tensor1Expression<B, K> &b) {
 | |
| 		return Rank2Rank1_<A, I, J, B, K> (a, b);
 | |
| 	}
 | |
| 
 | |
| 	/** Helper class for outer product of rank1 and rank2 tensor */
 | |
| 	template<class A, class I, class B, class J, class K>
 | |
| 	class Rank1Rank2_ {
 | |
| 		typedef Tensor1Expression<A, I> Rank1;
 | |
| 		typedef Tensor2Expression<B, J, K> Rank2;
 | |
| 		const Rank1 iterA;
 | |
| 		const Rank2 iterB;
 | |
| 	public:
 | |
| 		Rank1Rank2_(const Rank1 &a, const Rank2 &b) :
 | |
| 			iterA(a), iterB(b) {
 | |
| 		}
 | |
| 		double operator()(int i, int j, int k) const {
 | |
| 			return iterA(i) * iterB(j, k);
 | |
| 		}
 | |
| 	};
 | |
| 
 | |
| 	/** outer product of rank2 and rank1 tensor */
 | |
| 	template<class A, class I, class J, class B, class K>
 | |
| 	inline Tensor3Expression<Rank1Rank2_<A, I, B, J, K> , I, J, K> operator*(
 | |
| 			const Tensor1Expression<A, I>& a, const Tensor2Expression<B, J, K> &b) {
 | |
| 		return Rank1Rank2_<A, I, B, J, K> (a, b);
 | |
| 	}
 | |
| 
 | |
| 	/**
 | |
| 	 * This template works for any two expressions
 | |
| 	 */
 | |
| 	template<class A, class B, class I, class J, class K>
 | |
| 	bool assert_equality(const Tensor3Expression<A, I, J, K>& expected,
 | |
| 			const Tensor3Expression<B, I, J, K>& actual, double tol = 1e-9) {
 | |
| 		if (actual.equals(expected, tol)) return true;
 | |
| 		std::cout << "Not equal:\n";
 | |
| 		expected.print("expected:\n");
 | |
| 		actual.print("actual:\n");
 | |
| 		return false;
 | |
| 	}
 | |
| 
 | |
| } // namespace tensors
 |