I have solved this problem in many different languages, but not in Python, because my solution hits the 2-second time limit. It misses by just 5%, though, so even a slight performance improvement could be enough to make it pass. The problem is about “Disjoint Set Union (Union Find)”.
After testing dozens of versions, I selected the one below, which seems to have the best performance on the online judge:
from os import read
from os import fstat
from sys import stdout as out

def inputIter():
    # Parse all of stdin byte by byte. Bytes >= 48 are treated as digits;
    # any other byte (space, newline) ends the current number.
    vs = 0
    for v in read(0, fstat(0).st_size):
        if v >= 48:
            vs *= 10
            vs += v - 48
        else:
            yield vs
            vs = 0

iterator = inputIter()
ga = 0  # root found while walking up from a
gb = 0  # second element of the current relation

def linka(a):
    # Walk up from a to its root, then continue with linkb(gb);
    # on the way back, point every visited node at the merged root.
    global ga
    va = arr[a]
    mi = 0
    if va == a:
        ga = a
        mi = linkb(gb)
    else:
        mi = linka(va)
    arr[a] = mi
    return mi

def linkb(a):
    # Walk up from b to its root; the smaller of the two roots wins.
    va = arr[a]
    mi = 0
    if va == a:
        mi = min(a, ga)
    else:
        mi = linkb(va)
    arr[a] = mi
    return mi

isNotFirst = 0
try:
    while 1:
        students = next(iterator)
        relations = next(iterator)
        queries = next(iterator)
        arr = [0] * (students + 1)
        for i in range(1, students + 1):
            arr[i] = i
        for i in range(relations):
            a = next(iterator)
            gb = next(iterator)
            linka(a)
        # Flatten: parents always have a smaller index, so one pass suffices.
        for i in range(1, students + 1):
            a = arr[i]
            if a < i:
                arr[i] = arr[a]
        outbuff = ['N'] * queries
        for i in range(queries):
            if arr[next(iterator)] == arr[next(iterator)]:
                outbuff[i] = 'S'
        out.write("\n" * isNotFirst + "\n".join(outbuff) + "\n")
        isNotFirst = 1
except StopIteration:
    pass
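To make the intent of linka/linkb easier to follow, here is a minimal standalone sketch of the same idea: the smaller root always wins, and every node on a visited path is pointed straight at its root. The names make_sets, find and union are illustrative only and are not part of the submission.

```python
def make_sets(n):
    # parent[i] == i means i is the root of its own set; index 0 is unused.
    return list(range(n + 1))

def find(parent, x):
    # First pass: walk up to the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Second pass: point every node on the path directly at the root.
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root

def union(parent, a, b):
    # Merge the two sets; the smaller root becomes the root of both.
    ra, rb = find(parent, a), find(parent, b)
    if ra < rb:
        parent[rb] = ra
    else:
        parent[ra] = rb
    return min(ra, rb)

parent = make_sets(5)
union(parent, 4, 5)
union(parent, 2, 4)
same_group = find(parent, 5) == find(parent, 2)
```

Because a parent never has a larger index than its child, a single left-to-right pass over the array is enough to flatten it before answering queries.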
Profiling shows that reading the input is the bottleneck, but I haven't been able to make it faster.
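For comparison, a commonly used alternative to a byte-by-byte parser is to read everything at once and let bytes.split() and int() do the tokenizing. The sketch below assumes the input consists only of whitespace-separated non-negative integers; parse_ints and read_stdin_ints are hypothetical helper names, and whether this beats the byte loop on the judge is not something the measurements above establish.

```python
import sys

def parse_ints(data):
    # bytes.split() with no argument splits on any run of ASCII whitespace,
    # so "\r\n" line endings and repeated spaces produce no empty tokens.
    # int() accepts bytes objects directly.
    return map(int, data.split())

def read_stdin_ints():
    # Assumption: the whole input fits in memory, as in the original reader.
    return parse_ints(sys.stdin.buffer.read())

# Demonstration on a literal instead of stdin:
tokens = parse_ints(b"3 1 2\n1 2\r\n1 3\n")
students = next(tokens)
relations = next(tokens)
queries = next(tokens)
rest = list(tokens)
```

The returned iterator can be consumed with next(), so it would be a drop-in replacement for the generator in the main loop.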
Other things I tested which seem to perform slightly worse:
- switching [0] * n for list comprehensions, array.array or numpy
- shrinking the output code:
outbuff = [('S' if arr[next(iterator)] == arr[next(iterator)] else 'N') for _ in range(queries)]
out.write("\n" * isNotFirst + "\n".join(outbuff) + "\n")
isNotFirst = 1
- making the link function non-recursive:
def link2(a, b):
    stack = []
    va = arr[a]
    while va != a:
        stack.append(a)
        a = va
        va = arr[a]
    vb = arr[b]
    while vb != b:
        stack.append(b)
        b = vb
        vb = arr[b]
    root1 = min(va, vb)
    root2 = max(va, vb)
    arr[root2] = root1
    for v in stack:
        arr[v] = root1
    return root1
- a different approach that collects all relations first and then traverses them, giving a linear-time solution, O(students + relations):
arr = [0] * (students + 1)
rels = [[] for i in range(students + 1)]
for i in range(relations):
    a = next(iterator)
    b = next(iterator)
    rels[a].append(b)
    rels[b].append(a)
for i in range(1, students + 1):
    linknodes(i)
recursively:
def linknodes(i):
    linkrel(i, i)

def linkrel(current, groupid):
    if arr[current] == 0:
        arr[current] = groupid
        lst = rels[current]
        for friend in lst:
            linkrel(friend, groupid)
non-recursive:
def linknodes(groupid):
    ni = 0
    nodes = [groupid]
    while ni < len(nodes):
        n = nodes[ni]
        ni += 1
        if arr[n] == 0:
            arr[n] = groupid
            lst = rels[n]
            if len(lst) > 0:
                nodes.extend(lst)
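The traversal approach above can be put together as one self-contained sketch. It is a slight variation rather than the exact code I submitted: nodes are labelled when they are added to the queue instead of when they are taken off it, so no node is queued twice. The function name label_components and the edge-list argument are assumptions made for the example.

```python
def label_components(n, edges):
    # Adjacency lists; index 0 is unused so student ids can be used directly.
    rels = [[] for _ in range(n + 1)]
    for a, b in edges:
        rels[a].append(b)
        rels[b].append(a)

    group = [0] * (n + 1)  # 0 means "not visited yet"
    for start in range(1, n + 1):
        if group[start]:
            continue
        group[start] = start
        nodes = [start]  # list used as a queue, read through the index ni
        ni = 0
        while ni < len(nodes):
            cur = nodes[ni]
            ni += 1
            for friend in rels[cur]:
                if group[friend] == 0:
                    group[friend] = start
                    nodes.append(friend)
    return group

group = label_components(5, [(4, 5), (2, 4)])
same_group = group[5] == group[2]
```

Each student ends up labelled with the smallest id in their component, so a query is a single comparison of two list entries.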
So, is there anything else I can try? Thanks in advance!