I implemented a fixed-size matrix class which supports basic matrix operations (cannot use C++11). What do you think, and how could it be improved?
#ifndef FIXED_SIZE_MATRIX_H__
#define FIXED_SIZE_MATRIX_H__
#include <cassert>
#include <cstdlib>
#include <iostream>
// Fixed-size N x M matrix of Ty, stored row-major.
//
// The type is intentionally an aggregate (no constructors, no private data)
// so that it can be brace-initialized without C++11, e.g.
//   FixedSizeMatrix<float, 2, 2> m = {1.0f, 1.0f, 0.5f, 0.2f};
//
// Template parameters:
//   Ty - element type
//   N  - number of rows
//   M  - number of columns (defaults to N, giving a square matrix)
template <typename Ty, int N, int M = N>
struct FixedSizeMatrix {
  typedef Ty value_type;

  // Two views of the same storage: `element[r][c]` for 2-D access and
  // `flatten[i]` for element-wise loops. The arrays are direct members of the
  // anonymous union; wrapping them in anonymous structs is a compiler
  // extension, not ISO C++. `element` must stay first, because aggregate
  // initialization of a union initializes its first member.
  // NOTE(review): writing through one view and reading through the other
  // relies on both arrays having the same layout; ISO C++ formally only
  // blesses reading the active union member.
  union {
    Ty element[N][M];
    Ty flatten[N * M];
  };

  // Mutable element access. Indices are range-checked with assert only, so
  // the checks disappear when NDEBUG is defined.
  Ty& operator()(int r, int c) {
    assert(r >= 0 && r < N);
    assert(c >= 0 && c < M);
    return element[r][c];
  }

  // Read-only element access, so const matrices and const references can be
  // indexed too. Same assert-based range checks as the mutable overload.
  const Ty& operator()(int r, int c) const {
    assert(r >= 0 && r < N);
    assert(c >= 0 && c < M);
    return element[r][c];
  }

  // Returns the M x N transpose; this matrix is left untouched.
  FixedSizeMatrix<Ty, M, N> T() const {
    FixedSizeMatrix<Ty, M, N> result;
    for (int r = 0; r < N; ++r) {
      for (int c = 0; c < M; ++c) {
        result.element[c][r] = element[r][c];
      }
    }
    return result;
  }

  // Returns the matrix inverse. Only declared here; the definition lives
  // further down in this header, after the detail::inverse helpers.
  FixedSizeMatrix<Ty, N, M> I();
};
///
// Matrix Inverse
///
// A design choice was made to keep FixedSizeMatrix an aggregate class to enable
// FixedSizeMatrix<float, 2, 2> = {1.0f, 1.0f, 0.5f, 0.2f} initialization.
// With c++11 it would be possible to create a FixedSizeMatrix base class
// and derive all variations from it, while retaining the brace
// initialization.
//
// Matrix inverse helpers
namespace detail {

// Primary template: declared but never defined, so requesting the inverse of
// a matrix size that has no specialization below fails at compile time.
template <typename Ty, int N, int M>
struct inverse;

// Inverse of a 2x2 matrix using the closed-form adjugate formula:
//   inv([a b; c d]) = 1/det * [d -b; -c a],  where det = a*d - b*c.
// The input is taken by const reference to avoid copying it, and the call
// operator is const because the functor holds no state.
// A zero determinant is rejected with assert only, so it goes undetected
// when NDEBUG is defined.
template <typename Ty>
struct inverse<Ty, 2, 2> {
  FixedSizeMatrix<Ty, 2, 2> operator()(
      const FixedSizeMatrix<Ty, 2, 2>& a) const {
    FixedSizeMatrix<Ty, 2, 2> result;
    Ty det =
        a.element[0][0] * a.element[1][1] - a.element[0][1] * a.element[1][0];
    assert(det != 0);
    result.element[0][0] = a.element[1][1] / det;
    result.element[1][1] = a.element[0][0] / det;
    result.element[0][1] = -a.element[0][1] / det;
    result.element[1][0] = -a.element[1][0] / det;
    return result;
  }
};

}  // namespace detail
// Out-of-line definition of FixedSizeMatrix::I(): the work is delegated to
// the detail::inverse functor selected by the matrix's element type and
// dimensions.
template <typename Ty, int N, int M>
FixedSizeMatrix<Ty, N, M> FixedSizeMatrix<Ty, N, M>::I() {
  detail::inverse<Ty, N, M> invert;
  return invert(*this);
}
///
// Matrix operations
///
// Matrix product: (N x M) * (M x P) -> (N x P).
// The inner dimensions must agree; a mismatch makes template deduction fail,
// so it is a compile-time error. Operands are taken by const reference so
// that multiplying does not copy both matrices first.
template <typename Ty, int N, int M, int P>
FixedSizeMatrix<Ty, N, P> operator*(const FixedSizeMatrix<Ty, N, M>& a,
                                    const FixedSizeMatrix<Ty, M, P>& b) {
  FixedSizeMatrix<Ty, N, P> result;
  for (int r = 0; r < N; ++r) {
    for (int c = 0; c < P; ++c) {
      // Dot product of row r of `a` with column c of `b`.
      Ty accum = Ty(0);
      for (int i = 0; i < M; ++i) {
        accum += a.element[r][i] * b.element[i][c];
      }
      result.element[r][c] = accum;
    }
  }
  return result;
}
// Unary negation: returns a matrix whose every element is the negation of
// the corresponding element of `a`. The operand is taken by const reference
// to avoid copying it.
template <typename Ty, int N, int M>
FixedSizeMatrix<Ty, N, M> operator-(const FixedSizeMatrix<Ty, N, M>& a) {
  FixedSizeMatrix<Ty, N, M> result;
  for (int e = 0; e < N * M; ++e) result.flatten[e] = -a.flatten[e];
  return result;
}
// Element-wise binary operators between two matrices of identical shape.
// The helper macro stamps out one operator template per invocation and is
// undefined again right after use. Operands are taken by const reference to
// avoid copying them. The invocations carry no trailing ';': the macro
// expands to a complete function definition, and a stray ';' at namespace
// scope is an empty declaration, which is not valid before C++11.
#define MATRIX_WITH_MATRIX_OPERATOR(op_symbol, op)        \
  template <typename Ty, int N, int M>                    \
  FixedSizeMatrix<Ty, N, M> operator op_symbol(           \
      const FixedSizeMatrix<Ty, N, M>& a,                 \
      const FixedSizeMatrix<Ty, N, M>& b) {               \
    FixedSizeMatrix<Ty, N, M> result;                     \
    for (int e = 0; e < N * M; ++e)                       \
      result.flatten[e] = a.flatten[e] op b.flatten[e];   \
    return result;                                        \
  }
MATRIX_WITH_MATRIX_OPERATOR(+, +)
MATRIX_WITH_MATRIX_OPERATOR(-, -)
#undef MATRIX_WITH_MATRIX_OPERATOR
// Element-wise binary operators between a matrix (left) and a scalar
// (right): the scalar is applied to every element. The helper macro stamps
// out one operator template per invocation and is undefined again right
// after use. The matrix is taken by const reference to avoid copying it.
// The invocations carry no trailing ';': the macro expands to a complete
// function definition, and a stray ';' at namespace scope is an empty
// declaration, which is not valid before C++11.
#define MATRIX_WITH_SCALAR_OPERATOR(op_symbol, op)        \
  template <typename Ty, int N, int M>                    \
  FixedSizeMatrix<Ty, N, M> operator op_symbol(           \
      const FixedSizeMatrix<Ty, N, M>& a, Ty scalar) {    \
    FixedSizeMatrix<Ty, N, M> result;                     \
    for (int e = 0; e < N * M; ++e)                       \
      result.flatten[e] = a.flatten[e] op scalar;         \
    return result;                                        \
  }
MATRIX_WITH_SCALAR_OPERATOR(+, +)
MATRIX_WITH_SCALAR_OPERATOR(-, -)
MATRIX_WITH_SCALAR_OPERATOR(*, *)
MATRIX_WITH_SCALAR_OPERATOR(/, /)
#undef MATRIX_WITH_SCALAR_OPERATOR
// scalar + matrix: forwards to matrix + scalar with the operands swapped.
// The matrix is taken by const reference to avoid an extra copy.
template <typename Ty, int N, int M>
FixedSizeMatrix<Ty, N, M> operator+(Ty scalar,
                                    const FixedSizeMatrix<Ty, N, M>& a) {
  return a + scalar;
}
// scalar * matrix: forwards to matrix * scalar with the operands swapped.
// The matrix is taken by const reference to avoid an extra copy.
template <typename Ty, int N, int M>
FixedSizeMatrix<Ty, N, M> operator*(Ty scalar,
                                    const FixedSizeMatrix<Ty, N, M>& a) {
  return a * scalar;
}
// scalar - matrix: computed as (-matrix) + scalar, i.e. every element e
// becomes scalar - e. The matrix is taken by const reference to avoid an
// extra copy.
template <typename Ty, int N, int M>
FixedSizeMatrix<Ty, N, M> operator-(Ty scalar,
                                    const FixedSizeMatrix<Ty, N, M>& a) {
  return -a + scalar;
}
// Builds the N x N identity matrix: Ty(1) on the main diagonal, every other
// element value-initialized.
template <typename Ty, int N>
FixedSizeMatrix<Ty, N, N> identity_matrix() {
  typedef FixedSizeMatrix<Ty, N, N> Square;
  // Value-initialize first so that the off-diagonal elements are zeroed.
  Square eye = Square();
  for (int d = 0; d < N; ++d) {
    eye.element[d][d] = Ty(1);
  }
  return eye;
}
template <typename Ty, int N, int M>
std::ostream& operator<<(std::ostream& out, FixedSizeMatrix<Ty, N, M> a) {
for (int r = 0; r < N; ++r) {
for (int c = 0; c < M; ++c) {
out << a.element[r][c] << " ";
}
out << std::endl;
}
return out;
}
#endif