본문 바로가기
Algorithm/Graph

[파이썬] 백준 23034 : 조별과제 멈춰 (플레5)

by 베짱이28호 2023. 12. 22.

[파이썬] 백준 23034 : 조별과제 멈춰 (플레5)

 

23034번: 조별과제 멈춰!

교수님이 시험 기간에 조별 과제를 준비하셨다...! 가톨릭대학교의 조교 아리는 N명의 학생을 2개의 조로 구성하여 과제 공지를 하려 한다. 이때, 구성된 각 조의 인원은 1명 이상이어야 한다. 각

www.acmicpc.net


문제


풀이

0. 방향성 생각

  • 2개의 집합으로 분리 + 최소비용으로 간선 잇기 = 크루스칼
  • 두 간선 사이 연결 끊기 : a 노드부터 b 노드까지 간선 중 최대값을 끊으면 최소 비용으로 두 집합 분리 가능
  • a,b가 10000개 이상 주어지니, 시간초과 해결을 위해서 a와b 사이 최대값을 구하는 list를 만들어준다.

1. 입력

from collections import deque
import sys
input = lambda : sys.stdin.readline().rstrip()

n,m = map(int,input().split())
inf = 1e9
graph = [[] for _ in range(n+1)]
edges = []
for _ in range(m):
    a,b,c = map(int,input().split())
    graph[a].append((b,c))
    graph[b].append((a,c))
    edges.append((c,a,b))
  • MST를 구성하는데 쓸 edges와 연결정보를 넣은 graph

2. MST

parent = list(range(n+1))
def find(x):
    if x != parent[x]:
        parent[x] = find(parent[x])
    return parent[x]

def union(x,y):
    parent[max(x,y)] = min(x,y)

total_cost = 0
mst_graph = [[] for _ in range(n+1)]
edges.sort(reverse=True)
while edges:
    cost,a,b = edges.pop()
    pa,pb = find(a),find(b)
    if pa != pb:
        union(pa,pb)
        total_cost += cost
        mst_graph[a].append((b,cost))
        mst_graph[b].append((a,cost))
  • 크루스칼 돌려서 MST 구성해주기.
  • MST를 구성하는데 든 총 비용은 total_cost에 저장

3. 출력

dp = [[0]*(n+1) for _ in range(n+1)]
def bfs(s):
    q = deque([s])
    visit = [False]*(n+1)
    visit[s] = True
    while q:
        x = q.popleft()
        for nx,cost in mst_graph[x]:
            if not visit[nx]:
                dp[s][nx] = max(dp[s][x],cost)
                visit[nx] = True
                q.append(nx)

for i in range(n):
    bfs(i+1)

for _ in range(int(input())):
    a,b = map(int,input().split())
    print(total_cost-dp[a][b])
  • 각 노드에서 BFS를 실행해서 다른 노드까지 얼마나 걸리는지 계산하기.
  • s에서 nx까지 갈 때, 최대값을 dp[s][nx]에 넣어준다. 이전까지 값인 dp[s][x] or 현재 간선정보  cost 중 최대값 선택.

 


전체코드

from collections import deque
import sys
input = lambda : sys.stdin.readline().rstrip()

n,m = map(int,input().split())
inf = 1e9
graph = [[] for _ in range(n+1)]
edges = []
for _ in range(m):
    a,b,c = map(int,input().split())
    graph[a].append((b,c))
    graph[b].append((a,c))
    edges.append((c,a,b))

parent = list(range(n+1))
def find(x):
    if x != parent[x]:
        parent[x] = find(parent[x])
    return parent[x]

def union(x,y):
    parent[max(x,y)] = min(x,y)

total_cost = 0
mst_graph = [[] for _ in range(n+1)]
edges.sort(reverse=True)
while edges:
    cost,a,b = edges.pop()
    pa,pb = find(a),find(b)
    if pa != pb:
        union(pa,pb)
        total_cost += cost
        mst_graph[a].append((b,cost))
        mst_graph[b].append((a,cost))

dp = [[0]*(n+1) for _ in range(n+1)]
def bfs(s):
    q = deque([s])
    visit = [False]*(n+1)
    visit[s] = True
    while q:
        x = q.popleft()
        for nx,cost in mst_graph[x]:
            if not visit[nx]:
                dp[s][nx] = max(dp[s][x],cost)
                visit[nx] = True
                q.append(nx)

for i in range(n):
    bfs(i+1)

for _ in range(int(input())):
    a,b = map(int,input().split())
    print(total_cost-dp[a][b])

 

코멘트

처음 풀이

from collections import deque
import sys
input = lambda : sys.stdin.readline().rstrip()
n,m = map(int,input().split())

inf = 1e9
edges = [[inf]*(n+1) for _ in range(n+1)]
graph = [[] for _ in range(n+1)]
temp = []
for _ in range(m):
    a,b,c = map(int,input().split())
    edges[a][b] = c
    edges[b][a] = c
    graph[a].append(b)
    graph[b].append(a)
    temp.append((c,a,b))

parent = list(range(n+1))
def find(x):
    if x != parent[x]:
        parent[x] = find(parent[x])
    return parent[x]

def union(x,y):
    parent[max(x,y)] = min(x,y)

mst_graph = [[] for _ in range(n+1)]
temp.sort(reverse=True)
answer = 0
while temp:
    cost,a,b = temp.pop()
    pa,pb = find(a),find(b)
    if pa != pb:
        union(pa,pb)
        answer += cost
        mst_graph[a].append(b)
        mst_graph[b].append(a)

for idx,graph in enumerate(mst_graph):
    if len(graph) == 1:
        start = idx
        break

visit = [False]*(n+1)
visit[start] = True
q = deque([start])
trans = {start:0}
mst_cost = []
while q:
    x = q.popleft()
    for nx in mst_graph[x]:
        if not visit[nx]:
            q.append(nx)
            visit[nx] = True
            trans[nx] = trans[x] + 1
            mst_cost.append(edges[x][nx])

section_max = [[0]*(n-1) for _ in range(n-1)]
for i in range(n-1):
    for j in range(i,n-1):
        section_max[i][j] = max(section_max[i][j-1], mst_cost[j])

t = int(input())
for _ in range(t):
    a,b = map(int,input().split())
    a,b = trans[a],trans[b]
    if a > b: a,b = b,a
    print(answer - section_max[a][b-1])
  • MST를 구성하고 난 후에, 양 끝 지점중 하나를 찾아서 일렬로 펼치는 작업을 진행한다.
  • 각 노드 번호를 trans를 거쳐서 0부터 n-1까지 매핑한다.
  • mst_cost는 끝 지점에서 BFS를 돌리면서 생긴 간선들을 하나씩 더해준 것.
  • 마찬가지로 dp table을 업데이트 한다.
  • 노드 N개를 순회하는데 N^2걸리고 틀린 부분은 없어보이는데 그냥 열받아서 BFS로 돌렸음.

댓글