refactor HybridValues

release/4.3a0
Varun Agrawal 2022-08-26 19:35:45 -04:00
parent 1b5daf9a0e
commit 9c7bf36db6
1 changed files with 36 additions and 18 deletions

View File

@ -36,40 +36,58 @@ namespace gtsam {
* Optimizing a HybridGaussianBayesNet returns this class. * Optimizing a HybridGaussianBayesNet returns this class.
*/ */
class GTSAM_EXPORT HybridValues { class GTSAM_EXPORT HybridValues {
public: private:
// DiscreteValue stored the discrete components of the HybridValues. // DiscreteValue stored the discrete components of the HybridValues.
DiscreteValues discrete; DiscreteValues discrete_;
// VectorValue stored the continuous components of the HybridValues. // VectorValue stored the continuous components of the HybridValues.
VectorValues continuous; VectorValues continuous_;
public:
/// @name Standard Constructors
/// @{
// Default constructor creates an empty HybridValues. // Default constructor creates an empty HybridValues.
HybridValues() : discrete(), continuous(){}; HybridValues() = default;
// Construct from DiscreteValues and VectorValues. // Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv) HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete(dv), continuous(cv){}; : discrete_(dv), continuous_(cv){};
/// @}
/// @name Testable
/// @{
// print required by Testable for unit testing // print required by Testable for unit testing
void print(const std::string& s = "HybridValues", void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n"; std::cout << s << ": \n";
discrete.print(" Discrete", keyFormatter); // print discrete components discrete_.print(" Discrete", keyFormatter); // print discrete components
continuous.print(" Continuous", continuous_.print(" Continuous",
keyFormatter); // print continuous components keyFormatter); // print continuous components
}; };
// equals required by Testable for unit testing // equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const { bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete.equals(other.discrete, tol) && return discrete_.equals(other.discrete_, tol) &&
continuous.equals(other.continuous, tol); continuous_.equals(other.continuous_, tol);
} }
/// @}
/// @name Interface
/// @{
/// Return the discrete MPE assignment
DiscreteValues discrete() const { return discrete_; }
/// Return the delta update for the continuous vectors
VectorValues continuous() const { return continuous_; }
// Check whether a variable with key \c j exists in DiscreteValue. // Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
// Check whether a variable with key \c j exists in VectorValue. // Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous.exists(j); }; bool existsVector(Key j) { return continuous_.exists(j); };
// Check whether a variable with key \c j exists. // Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
@ -78,13 +96,13 @@ class GTSAM_EXPORT HybridValues {
* the key \c j is already used. * the key \c j is already used.
* @param value The vector to be inserted. * @param value The vector to be inserted.
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, int value) { discrete[j] = value; }; void insert(Key j, int value) { discrete_[j] = value; };
/** Insert a vector \c value with key \c j. Throws an invalid_argument /** Insert a vector \c value with key \c j. Throws an invalid_argument
* exception if the key \c j is already used. * exception if the key \c j is already used.
* @param value The vector to be inserted. * @param value The vector to be inserted.
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous.insert(j, value); } void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
@ -92,13 +110,13 @@ class GTSAM_EXPORT HybridValues {
* Read/write access to the discrete value with key \c j, throws * Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist. * std::out_of_range if \c j does not exist.
*/ */
size_t& atDiscrete(Key j) { return discrete.at(j); }; size_t& atDiscrete(Key j) { return discrete_.at(j); };
/** /**
* Read/write access to the vector value with key \c j, throws * Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist. * std::out_of_range if \c j does not exist.
*/ */
Vector& at(Key j) { return continuous.at(j); }; Vector& at(Key j) { return continuous_.at(j); };
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
@ -112,8 +130,8 @@ class GTSAM_EXPORT HybridValues {
std::string html( std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::stringstream ss; std::stringstream ss;
ss << this->discrete.html(keyFormatter); ss << this->discrete_.html(keyFormatter);
ss << this->continuous.html(keyFormatter); ss << this->continuous_.html(keyFormatter);
return ss.str(); return ss.str();
}; };