add tie-breaking test

release/4.3a0
Varun Agrawal 2024-07-15 17:46:26 -04:00
parent 4a04963197
commit 83eff969c5
1 changed files with 16 additions and 10 deletions

View File

@ -290,26 +290,32 @@ TEST(DiscreteConditional, choose) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Check argmax on P(C|D) and P(D) // Check argmax on P(C|D) and P(D), plus tie-breaking for P(B)
TEST(DiscreteConditional, Argmax) { TEST(DiscreteConditional, Argmax) {
DiscreteKey C(2, 2), D(4, 2); DiscreteKey B(2, 2), C(2, 2), D(4, 2);
DiscreteConditional B_prior(D, "1/1");
DiscreteConditional D_prior(D, "1/3"); DiscreteConditional D_prior(D, "1/3");
DiscreteConditional C_given_D((C | D) = "1/4 1/1"); DiscreteConditional C_given_D((C | D) = "1/4 1/1");
// Case 1: No parents // Case 1: Tie breaking
size_t actual1 = D_prior.argmax(); size_t actual1 = B_prior.argmax();
EXPECT_LONGS_EQUAL(1, actual1); // In the case of ties, the first value is chosen.
EXPECT_LONGS_EQUAL(0, actual1);
// Case 2: No parents
size_t actual2 = D_prior.argmax();
// Selects 1 since it has 0.75 probability
EXPECT_LONGS_EQUAL(1, actual2);
// Case 2: Given parent values // Case 3: Given parent values
DiscreteValues given; DiscreteValues given;
given[D.first] = 1; given[D.first] = 1;
size_t actual2 = C_given_D.argmax(given); size_t actual3 = C_given_D.argmax(given);
// Should be 0 since D=1 gives 0.5/0.5 // Should be 0 since D=1 gives 0.5/0.5
EXPECT_LONGS_EQUAL(0, actual2); EXPECT_LONGS_EQUAL(0, actual3);
given[D.first] = 0; given[D.first] = 0;
size_t actual3 = C_given_D.argmax(given); size_t actual4 = C_given_D.argmax(given);
EXPECT_LONGS_EQUAL(1, actual3); EXPECT_LONGS_EQUAL(1, actual4);
} }
/* ************************************************************************* */ /* ************************************************************************* */