harurun競プロ

python勢

ABC243E - Edge Deletion をダイクストラ法で通す

atcoder.jp

問題文

N 頂点 M 辺の単純連結無向グラフが与えられます。 辺 i は頂点 A_i と頂点 B_i を結ぶ長さ C_i の辺です。

以下の条件を満たすようにいくつかの辺を削除します。削除する辺の数の最大値を求めてください。

  • 辺を削除した後のグラフも連結である。
  • 全ての頂点対 (s,t) について、頂点 s と頂点 t の間の距離が削除前と削除後で変化しない。

制約

  • 2≤N≤300
  • N−1≤M≤N(N-1)/2
  • 1≤A_i<B_i≤N
  • 1≤C_i≤109
  • i≠j ならば (A_i,B_i)≠(A_j,B_j)) である。
  • 与えられるグラフは連結である。
  • 入力はすべて整数である。

方針

(制約的にワ―シャルフロイドなのだろうと思いながら)

問題文を言い変えると、

各頂点から各頂点への最短経路で使われなかった辺を削除するときの削除する辺の数の最大値を求めよ

である。

頂点 i から各頂点への最短経路は ダイクストラ法で O(MlogN) であり、i を 1∼Nまで変化させるので、合計で O(NMlogN) である。

どの辺が使われたかについては、直前の頂点を覚えていればO(N)で判定可能である。

これを実装すると以下になり、WAを喰らう。

#include <bits/stdc++.h>
using namespace std;
using ll=long long;

struct Dijkstra{
  private:
    long long INF=1000000000000000000LL;
    std::vector<std::vector<std::pair<long long,long long>>> G;
    int V;
    inline void _add(long long s,long long t,long long c){
      G[s].emplace_back(std::make_pair(t,c));
    }
  public:
    Dijkstra(int N):V(N) {
      G.resize(V);
    }
    void add1(long long s,long long t,long long c){
      _add(s,t,c);
    }
    void add2(long long s,long long t,long long c){
      _add(s,t,c);
      _add(t,s,c);
    }
    std::vector<long long> run(long long s){
      std::vector<long long> res(V,INF);
      vector<ll> pre(V,-1);
      std::priority_queue<std::pair<long long,long long>,std::vector<std::pair<long long,long long>>,std::greater<std::pair<long long,long long>>> que;
      res[s]=0;
      que.push(std::make_pair(0,s));
      while(!que.empty()){
        std::pair<long long,long long> p=que.top();
        que.pop();
        long long v=p.second;
        if(res[v]<p.first){
          continue;
        }
        for(const std::pair<long long,long long>& e:G[v]){
          if(res[e.first]>res[v]+e.second){
            pre[e.first]=v;
            res[e.first]=res[v]+e.second;
            que.push(std::make_pair(res[e.first],e.first));
          }
        }
      }
      return pre;
    }
};

int main(){
  ll N,M;
  cin>>N>>M;
  Dijkstra ijk(N);
  ll A,B,C;
  for(int i=0;i<M;i++){
    cin>>A>>B>>C;
    ijk.add2(A-1,B-1,C);
  }
  vector<vector<ll>> d(N,vector<ll>(N,0));
  for(int i=0;i<N;i++){
    auto res=ijk.run(i);
    for(int j=0;j<N;j++){
      if(res[j]!=-1){
        d[res[j]][j]++;
        d[j][res[j]]++;
      }
    }
  }
  ll ans=0;
  for(int i=0;i<N;i++){
    for(int j=i+1;j<N;j++){
      if(d[i][j]!=0){
        ans++;
      }
    }
  }
  cout<<M-ans<<endl;
}

atcoder.jp

原因

3 3
1 3 3
1 2 1
2 3 2

のような入力があったときに、頂点 1 から 頂点 3 までの最短経路が2通りできてしまい、処理順により答えが変わってしまうことである。

これを解決するためには、頂点 v から 頂点 u への最短経路は、最短経路上により多くの頂点を含むように実装すればよい。

(u から v への最短経路上の隣接する頂点を p,q とすると、多重辺がないため、辺(p,q)は必ず使われることになる。よって頂点をより多く含むほうを最短経路とすれば良い。)

これを実装すると以下のようになり、ACが得られる。

なお、Pythonでは計算量が怪しく、TLEするかもしれないのでC++などの高速な言語を使ったほうが良い。

#include <bits/stdc++.h>
using namespace std;
using ll=long long;

struct Dijkstra{
  private:
    long long INF=1000000000000000000LL;
    std::vector<std::vector<std::pair<long long,long long>>> G;
    int V;
    inline void _add(long long s,long long t,long long c){
      G[s].emplace_back(std::make_pair(t,c));
    }
  public:
    Dijkstra(int N):V(N) {
      G.resize(V);
    }
    void add1(long long s,long long t,long long c){
      _add(s,t,c);
    }
    void add2(long long s,long long t,long long c){
      _add(s,t,c);
      _add(t,s,c);
    }
    std::vector<long long> run(long long s){
      std::vector<long long> res(V,INF);
      vector<ll> pre(V,-1);
      vector<ll> cnt(V,0);
      std::priority_queue<std::pair<long long,long long>,std::vector<std::pair<long long,long long>>,std::greater<std::pair<long long,long long>>> que;
      res[s]=0;
      que.push(std::make_pair(0,s));
      while(!que.empty()){
        std::pair<long long,long long> p=que.top();
        que.pop();
        long long v=p.second;
        if(res[v]<p.first){
          continue;
        }
        for(const std::pair<long long,long long>& e:G[v]){
          if((res[e.first]==res[v]+e.second && cnt[e.first]<cnt[v]+1) || res[e.first]>res[v]+e.second){
            pre[e.first]=v;
            cnt[e.first]=cnt[v]+1;
            res[e.first]=res[v]+e.second;
            que.push(std::make_pair(res[e.first],e.first));
          }
        }
      }
      return pre;
    }
};

int main(){
  ll N,M;
  cin>>N>>M;
  Dijkstra ijk(N);
  ll A,B,C;
  for(int i=0;i<M;i++){
    cin>>A>>B>>C;
    ijk.add2(A-1,B-1,C);
  }
  
  vector<vector<ll>> d(N,vector<ll>(N,0));
  for(int i=0;i<N;i++){
    auto res=ijk.run(i);
    for(int j=0;j<N;j++){
      if(res[j]!=-1){
        d[res[j]][j]++;
        d[j][res[j]]++;
      }
    }
  }
  ll ans=0;
  for(int i=0;i<N;i++){
    for(int j=i+1;j<N;j++){
      if(d[i][j]!=0){
        ans++;
      }
    }
  }
  cout<<M-ans<<endl;
}

atcoder.jp

追記

O(N2) のダイクストラ法を使えば O(N3) になるとの指摘があったので追記。

これならPyPyでも間に合う。

class dijkstra_V2:
  def __init__(self,V):
    self.INF=float("inf")
    self.cost=[[self.INF]*V for _ in [0]*V]
    self.V=V
    return

  def add1(self,s,t,v):
    self.cost[s][t]=v
    return

  def add2(self,s,t,v):
    self.cost[s][t]=v
    self.cost[t][s]=v
    return

  def run(self,s):
    ret=[self.INF]*self.V
    ret[s]=0
    used=[False]*self.V
    pre=[-1]*self.V
    cnt=[0]*self.V
    while 1:
      v=-1
      for u in range(self.V):
        if (not used[u]) and (v==-1 or ret[u]<ret[v]):
          v=u
      if v==-1:
        break
      used[v]=True
      for u in range(self.V):
        if (ret[u]==ret[v]+self.cost[v][u] and cnt[u]<cnt[v]+1) or (ret[u]>ret[v]+self.cost[v][u]):
          ret[u]=ret[v]+self.cost[v][u]
          pre[u]=v
          cnt[u]=cnt[v]+1
    return pre

import sys

def main():
  N,M=map(int,sys.stdin.readline().split())
  ijk=dijkstra_V2(N)
  for i in range(M):
    A,B,C=map(int,sys.stdin.readline().split())
    ijk.add2(A-1,B-1,C)
  d=[[0]*N for i in range(N)]
  for i in range(N):
    res=ijk.run(i)
    for j in range(N):
      if res[j]!=-1:
        d[min(res[j],j)][max(res[j],j)]+=1
  ans=0
  for i in range(N):
    for j in range(i+1,N):
      if d[i][j]==0 and ijk.cost[i][j]!=ijk.INF:
        ans+=1
  print(ans)
  return

main()

atcoder.jp

さらに追記

O(NMlogN) でもPyPyは間に合いました

atcoder.jp