|
|
@ -497,18 +497,16 @@ private: |
|
|
|
case NodeType::JUST_1: return {0, 0, {}}; |
|
|
|
case NodeType::JUST_0: return {0, {}, 0}; |
|
|
|
case NodeType::THRESH: { |
|
|
|
uint32_t stat = 0, dsat = 0; |
|
|
|
int32_t sat_sum = 0; |
|
|
|
std::vector<int32_t> diffs; |
|
|
|
uint32_t stat = 0; |
|
|
|
auto sats = Vector(internal::MaxInt<uint32_t>(0)); |
|
|
|
for (const auto& sub : subs) { |
|
|
|
stat += sub->ops.stat + 1; |
|
|
|
dsat += sub->ops.dsat.value; // The type system requires "d" for thresh, so dsat must always be valid.
|
|
|
|
if (sub->ops.sat.valid) diffs.push_back((int32_t)sub->ops.sat.value - (int32_t)sub->ops.dsat.value); |
|
|
|
auto next_sats = Vector(sats[0] + sub->ops.dsat); |
|
|
|
for (size_t j = 1; j < sats.size(); ++j) next_sats.push_back(Choose(sats[j] + sub->ops.dsat, sats[j - 1] + sub->ops.sat)); |
|
|
|
next_sats.push_back(sats[sats.size() - 1] + sub->ops.sat); |
|
|
|
sats = std::move(next_sats); |
|
|
|
} |
|
|
|
if (diffs.size() < k) return {stat, {}, dsat}; |
|
|
|
std::sort(diffs.begin(), diffs.end()); |
|
|
|
for (size_t i = diffs.size() - k; i < diffs.size(); ++i) sat_sum += diffs[i]; |
|
|
|
return {stat, sat_sum + dsat, dsat}; |
|
|
|
return {stat, sats[k], sats[0]}; |
|
|
|
} |
|
|
|
} |
|
|
|
assert(false); |
|
|
@ -543,17 +541,14 @@ private: |
|
|
|
case NodeType::JUST_1: return {0, {}}; |
|
|
|
case NodeType::JUST_0: return {{}, 0}; |
|
|
|
case NodeType::THRESH: { |
|
|
|
uint32_t dsat = 0; |
|
|
|
int32_t sat_sum = 0; |
|
|
|
std::vector<int32_t> diffs; |
|
|
|
auto sats = Vector(internal::MaxInt<uint32_t>(0)); |
|
|
|
for (const auto& sub : subs) { |
|
|
|
dsat += sub->ss.dsat.value; // The type system requires "d" for thresh, so dsat must always be valid.
|
|
|
|
if (sub->ss.sat.valid) diffs.push_back((int32_t)sub->ss.sat.value - (int32_t)sub->ss.dsat.value); |
|
|
|
auto next_sats = Vector(sats[0] + sub->ss.dsat); |
|
|
|
for (size_t j = 1; j < sats.size(); ++j) next_sats.push_back(Choose(sats[j] + sub->ss.dsat, sats[j - 1] + sub->ss.sat)); |
|
|
|
next_sats.push_back(sats[sats.size() - 1] + sub->ss.sat); |
|
|
|
sats = std::move(next_sats); |
|
|
|
} |
|
|
|
if (diffs.size() < k) return {{}, dsat}; |
|
|
|
std::sort(diffs.begin(), diffs.end()); |
|
|
|
for (size_t i = diffs.size() - k; i < diffs.size(); ++i) sat_sum += diffs[i]; |
|
|
|
return {sat_sum + dsat, dsat}; |
|
|
|
return {sats[k], sats[0]}; |
|
|
|
} |
|
|
|
} |
|
|
|
assert(false); |
|
|
|