ABC 403 D - Forbidden Difference

#a3477264802e400c8248c8e1e212943c
2026.4.18
2026.4.18
  • 問題: https://atcoder.jp/contests/abc403/tasks/abc403_d

  • Aのうち、Dで割った剰余が異なる要素は互いに影響しないので、別々に考える

  • 例としてD=3、剰余が0の場合を考える

    • どの数字を削除するべきかは、出現個数によって変化する

    • 同じ数字の出現が複数ある場合、全て残すか、全て消すかのいずれかとなる

    • i番目の数字を(全て)削除する場合、削除しない場合で分岐するようなDPを考えれば良さそう、となる

  • dp[i]: i番目の数字までで条件を満たすのに、削除すべき数字の最小数

  • ベースケース: dp[0]=0

  • 遷移

    • i番目を消し、i+1番目に遷移。i番目の数字の個数がコストとして加算される

      • dp[i+1]=\min(dp[i+1],dp[i]+costs[i])

    • i番目を残す

      • i番目の数字とi+1番目の差がDであるとき: i+1番目の数字を削除しなければならない。よってi+1番目の削除コストを加えてi+2に遷移

        • dp[i+2]=\min(dp[i+2],dp[i]+costs[i+1])

      • そうでないとき: 追加コスト0でi+1に遷移

        • dp[i+1]=\min(dp[i+1],dp[i])

  • D=0の時に0割りでREとなったので、条件分岐して別のロジックで計算

    • D=0の場合も一般的に解ける解法が思いつかない

template <typename T> auto runLengthEncode(const T &container) {
  using ValueType = std::remove_cvref_t<decltype(*std::begin(container))>;

  std::vector<std::pair<ValueType, llong>> result;

  auto it = std::begin(container);
  auto end = std::end(container);

  if (it == end) {
    return result;
  }

  ValueType last = *it;
  llong length = 0;

  for (; it != end; ++it) {
    if (*it == last) {
      length++;
    } else {
      result.emplace_back(last, length);
      last = *it;
      length = 1;
    }
  }

  result.emplace_back(last, length);
  return result;
}

void answer() {
  llong N, D;
  read(N, D);
  std::vector<llong> A(N);
  read(A);

  std::sort(A.begin(), A.end());
  auto encoded = runLengthEncode(A);

  if (D == 0) {
    llong count = 0;

    for (auto &&[A_i, c] : encoded) {
      count += c - 1;
    }

    writeln(count);
    return;
  }

  std::unordered_map<llong, std::vector<std::pair<llong, llong>>> mod;

  for (auto &&[A_i, count] : encoded) {
    mod[A_i % D].push_back({A_i, count});
  }

  llong count = 0;

  for (auto &&[rem, list] : mod) {
    std::vector<llong> table(list.size() + 1,
                             std::numeric_limits<llong>::max() / 2);
    table[0] = 0;

    for (llong i = 0; i < list.size(); i++) {
      if (i + 1 < list.size() && list[i + 1].first == list[i].first + D) {
        table[i + 1] = std::min(table[i + 1], table[i] + list[i].second);

        if (i + 2 < table.size()) {
          table[i + 2] = std::min(table[i + 2], table[i] + list[i + 1].second);
        }
      } else {
        table[i + 1] = std::min(table[i + 1], table[i]);
      }
    }

    count += table[list.size()];
  }

  writeln(count);
}