본문 바로가기
Algorithm/Graph

[파이썬] 백준 1774 : 우주신과의 교감 (골드3)

by 베짱이28호 2023. 12. 13.

[파이썬] 백준 1774 : 우주신과의 교감 (골드3)


문제


풀이

0. 방향성 생각

  • MST
  • 미리 연결된 간선이 들어오면 먼저 union해주기

1. 입력

import sys
input = lambda : sys.stdin.readline().rstrip()

n,m = map(int,input().split())
locations = [None] + [tuple(map(int,input().split())) for _ in range(n)]

parent = list(range(n+1))
for _ in range(m):
    x,y = map(int,input().split())
    union(x,y)
  • location에 x,y 위치 받아주기
  • 연결된 간선끼리는 union 해주기.

2. 함수 정의

def find(a):
    if parent[a] != a:
        parent[a] = find(parent[a])
    return parent[a]
 
def union(a,b):
    pa,pb = find(a),find(b)
    if pa != pb:
        parent[max(pa,pb)] = min(pa,pb)

def cal(a,b):
    xa,ya = locations[a]
    xb,yb = locations[b]
    return ((xa-xb)**2+(ya-yb)**2)**0.5
  • 부모찾기, 유니온, 거리계산 함수 정의

3. 출력

edges = {}
for i in range(1,n+1):
    for j in range(1,n+1):
        edges[(j,i)] = cal(j,i)
            
edges = list(edges.items())
edges.sort(key = lambda x:-x[1])

answer = 0
while edges:
    nodes,cost = edges.pop()
    x,y = nodes
    if find(x) != find(y):
        union(x,y)
        answer += cost
print(f'{answer:.2f}')
  • 가벼운 간선부터 뽑아주면서 서로 다른 집합일 경우 union
  • 그냥 sort시키고 for문 하는게 더 빠를듯 (while pop이 더 빨리나옴)

 


전체코드

import sys
input = lambda : sys.stdin.readline().rstrip()

def find(a):
    if parent[a] != a:
        parent[a] = find(parent[a])
    return parent[a]
 
def union(a,b):
    pa,pb = find(a),find(b)
    if pa != pb:
        parent[max(pa,pb)] = min(pa,pb)

def cal(a,b):
    xa,ya = locations[a]
    xb,yb = locations[b]
    return ((xa-xb)**2+(ya-yb)**2)**0.5

n,m = map(int,input().split())
locations = [None] + [tuple(map(int,input().split())) for _ in range(n)]

parent = list(range(n+1))
for _ in range(m):
    x,y = map(int,input().split())
    union(x,y)
    
edges = {}
for i in range(1,n+1):
    for j in range(1,n+1):
        edges[(j,i)] = cal(j,i)
            
edges = list(edges.items())
edges.sort(key = lambda x:-x[1])

answer = 0
while edges:
    nodes,cost = edges.pop()
    x,y = nodes
    if find(x) != find(y):
        union(x,y)
        answer += cost
print(f'{answer:.2f}')

 

코멘트

**0.5로 루트씌우는건데 0.5 곱해버림......

댓글