htkb-proconの日記

初心者がPythonで問題解いた記録

AtCoder青になるまで使ったPythonパフォーマンス小ネタ

f:id:htkb:20190107210932j:plain

AtCoder青になったらなんか書こうと思っていて、12/22に無事青になれたのですが、よく考えたらあんま書くことないな……と思ってる間に年が開けてしまいました。知識披露でもサクセスストーリー的な話でも既に109+7人くらいいる青コーダー達が語り尽くしてしまっているのですが、知識の整理を兼ねてPythonという遅い言語で競プロをやる上でのパフォーマンス関連の小ネタを少しまとめてみました。きっとどこかで見たような話ばっかりですが。

残念ながらPython競プロにおいて重要なnumpy/scipyネタは語れるほど知らないのでありません。LLだとnumpy/scipy使わないとほぼ絶対通せない系の問題がたまにあるのでいずれちゃんとやらねばと思っていますが……。

以下でのベンチの数字は Core i5-8259U, DDR4-2400 8G, Ubuntu 18.10 x64, Python 3.7.0 でのデータです。


リストのindexingを極力避ける

リストaに対してa[0]のようにしてアクセスするアレですが、何故かPythonはこれが異様に遅いです。

a = list(range(10**6))

%%timeit
ans = 0
for i in range(len(a)):
    if a[i] % 2 == 0:
        ans += 1

リストa内に偶数の要素が何個あるか愚直に数えている感じですが、これをipythonで実行すると

72.7 ms ± 910 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

こんな感じに実行時間が出てきます。ではaの要素を直接変数で受けて判定したらどうでしょう。

%%timeit
ans = 0
for n in a:
    if n % 2 == 0:
        ans += 1

45.7 ms ± 282 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

と6割くらいになりました。実は最初の書き方、

%%timeit
ans = 0
getitem = a.__getitem__
for i in range(len(a)):
    if getitem(i) % 2 == 0:
        ans += 1

76.8 ms ± 1.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

特殊メソッドを呼んで要素を得るのに近いくらい遅いです。なんとなくカジュアルにa[0]とかやってたのが関数呼び出しくらい遅いとかヤバヤバですよね。インデックスが欲しい場合でもenumeratezipでインデックス用のrangeをまとめて回すなどしたほうがかなり速いです。

これは差が大きいのでどのような用途であれ不必要にリストを直接触るのを避けたほうがいい(受ける変数名に適切な名前を付ければコードの見通しも良くなる)ですが、特に滅茶苦茶影響が大きいアルゴリズムとしてワーシャルフロイド法があります。

D - バスと避けられない運命
Submission #3917065 - AtCoder Beginner Contest 012
Submission #3917007 - AtCoder Beginner Contest 012

ほぼワーシャルフロイドやるだけの問題でTLは5秒。上は教科書的なワーシャルフロイドで下はindexingを極力避けて最適化したものですが、最適化版は3秒でACなのが普通にやると5秒あってもTLEです。ワシャフロはソラで書ける実装の軽さで女子高生にも人気ですが、Pythonの場合は最適化したものをライブラリに持っておいたほうがいいです(上のリンク先の自作ワシャフロ使っていいよ!エッヘン)。


組み込み関数・標準ライブラリを使いこなそう

PythonはCで書かれており、組み込み関数や標準ライブラリで完結する処理はCにやらせてるようなもん(超適当)なので、Pythonの世界内で自前実装するより遥かに高速であることが多いです。

In [71]: %%timeit
    ...: ans = 1
    ...: for i in range(2, 10**5+1):
    ...:     ans *= i
    ...:
1.82 s ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [72]: %%timeit
    ...: ans = math.factorial(10**5)
    ...: 
108 ms ± 303 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

105!は456573桁にも上るとんでもない数ですが、math.factorialを使うと一度きりなら十分高速にmod無しで計算しつくすこともできます(メモリもそれほど食いません)。さすがにバカでかい数を何度も取り回してると遅いので、繰り返し階乗を求める場合はループで逐一modを取ったり1からNまでの階乗をループ回しながらリストに保存していくなどする必要があり、そういった標準で提供されない機能に限り自分で実装するわけですが、意外と競プロ的に都合のいい関数やオプションが標準で用意されていたりします。

お恥ずかしながら私も最近まで知らなかったのですが、Pythonのべき乗を求めるpow()は第三引数にmodを与えることでmodを取りながら計算することができ、おそらく内部的に繰り返し二乗法のような効率的な実装がなされているので、繰り返し二乗法も自分で実装する必要はありません。

In [88]: %%timeit -n1 -r1
    ...: # 12345678 ^ 1000000000000 % 1000000007
    ...: n, exp, mod = 12345678, 10**12, 10**9+7
    ...: ans = 1
    ...: 
    ...: while exp:
    ...:     if exp & 1:
    ...:         ans = (ans * n) % mod
    ...:     n = (n * n) % mod
    ...:     exp >>= 1
    ...: print(ans)
    ...:
174484875
20.3 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [89]: %%timeit -n1 -r1
    ...: print(pow(12345678, 10**12, 10**9+7))
    ...: 
174484875
14.9 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

他によく使う例として、二点間のユークリッド距離を求めるときにピタゴラスの定理を使いますが、自分で二乗してルートしたりしなくてもmath.hypot()にx軸の距離とy軸の距離をぶん投げるだけでよく、距離がマイナスでもよしなに取り計らってくれます。

In [94]: math.hypot(1, -1)
Out[94]: 1.4142135623730951 

小数を返す関数は、小数の精度の面でも自前であれこれやるよりメリットがあります。

上で述べたリストの参照回数を減らすためのenumeratezipなども含め、Pythonが用意してくれている便利機能を知り使いこなすことは競プロにおいても非常に重要です。競プロ中級以降へ進み速度面からC++へ移行したとしても、幅広い知識を持って使うPythonは素早い雑魚散らしに変わらず役立つことでしょう。公式リファレンスは一通り読んでおいて絶対に損はありません。

Python 標準ライブラリ

ちなみに、math.gcd()など一部の比較的新しいバージョンで追加された関数は、Python3.4環境であるAtCoderでは使えません。3.5から追加された機能は多いのでバージョン問題には気をつけてください。


itertoolsを使いこなそう(ただしパフォーマンス上の懸念あり)

上と被りますが、Pythonのとっても便利なライブラリitertoolsイテレータをうまいことアレコレしてくれる便利なツールで、特に競プロ的には便利極まるものです。

In [31]: for a in itertools.permutations(range(3)):
    ...:     print(a)
    ...: 
(0, 1, 2)
(0, 2, 1)
(1, 0, 2)
(1, 2, 0)
(2, 0, 1)
(2, 1, 0)

In [32]: for a in itertools.combinations(range(3), r=2):
    ...:     print(a)
    ...: 
(0, 1)
(0, 2)
(1, 2)

In [33]: for y, x in itertools.product(range(3), repeat=2):
    ...:     print(f"x={x}, y={y}")
    ...: 
x=0, y=0
x=1, y=0
x=2, y=0
x=0, y=1
x=1, y=1
x=2, y=1
x=0, y=2
x=1, y=2
x=2, y=2

これらを使わなくても愚直に書けますが、多重ループにした上に重複チェックが必要になったりとネストが深く煩雑になりがちなところをスッキリ書けるメリットは大きく、使えるところでは積極的に使っていきたいライブラリです。

また、itertools.permutationsitertools.combinationsでは重複のないイテレータを返してくれるため、自前でループ内で重複チェックするよりかなり定数倍が抑えられるのでは?と思い、もともとこの記事を書こうと思ったときはパフォーマンス最適化の側面からもこれらのライブラリを推そうと思っていました。ところがコードを試してみると……

てな感じで愚直に多重ループなどで書き下すより遅くなる場合があることが判明してしまいました。リプにてヒントを頂けました(はっほーさんありがとうございます!)が、どうもitertoolsのこれらの関数は引数を展開して内部に持っておくような感じで、極端な例でいうと

f:id:htkb:20190107140009j:plain

rangeを先にリストに展開している3番目のケースとitertoolsを使った左2ケースのメモリ使用量に注目してください。内部的には同じように展開して保持しているものと思われるため、ある程度以上の長さのイテレータ等を投げるとメモリ使用量やメモリアクセスが足を引っ張るような感じです。というか展開しないと順列や組み合わせを作れないのでよく考えればあたりまえの挙動な気もしますが……。

実際のところこんなに長いrangeイテレータを与えることはなく、もしそういうケースがあるならそもそもPythonではどうやってもTLEなので、基本的にはitertoolsを使って楽をしていくのがいいと思いますが、このようなメモリ上の問題があることは知った上で使ったほうがよさそうです。特にどうもproductは遅いようなので定数倍が気になる問題では避けたほうがいいのかも。便利なんですけどね。


属性アクセスをキャッシュする

例えばリストのお尻にどんどん要素を追加していく場合、ループ内でa.append(x)などとするわけですが、ループ外でa.appendを変数に取っておくと少し速くなります。

%%timeit
a = []
for i in range(10**6):
    a.append(i)

62.1 ms ± 337 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%timeit
a = []
append = a.append
for i in range(10**6):
    append(i)

47.1 ms ± 278 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

この差は何でしょうか?私は低レイヤわからないマンですが、disライブラリを使って逆アセンブルのまねごとをしてみます。

In [50]: import dis
In [51]: a = []

In [52]: append = a.append

In [53]: dis.dis(lambda: a.append(1))
  1           0 LOAD_GLOBAL              0 (a)
              2 LOAD_METHOD              1 (append)
              4 LOAD_CONST               1 (1)
              6 CALL_METHOD              1
              8 RETURN_VALUE

In [54]: dis.dis(lambda: append(1))
  1           0 LOAD_GLOBAL              0 (append)
              2 LOAD_CONST               1 (1)
              4 CALL_FUNCTION            1
              6 RETURN_VALUE

a.append(1)のほうはグローバル変数aを読み込んだ後、appendメソッドをさらに探しています。ここで初めて指定のメソッドが存在しなければAttributeErrorを吐くなどのプロセスも踏んでいるでしょう(lambda定義時点では存在しないメソッドを指定しても怒られないため)。一方でappend(1)のほうはグローバル変数appendを読み込むだけで終わりです。つまりメソッドの名前探索やAttributeErrorの処理が全部終わった状態から始められるということです。

リストのindexing問題ほど生死を分かつようなことはあまりない気がしていますが、全探索やグラフの最短経路探索などでキューやスタックをぶん回すときは、おまじないとしてループ外でキャッシュしておくといいでしょう。また優先度キューでheapqを使うときや二分探索でbisectを使うときは

from heapq import heappush, heappop
from bisect import bisect_left

のようにしてimportすると同じ効果が得られ手間が省けます。ただし、import *はお行儀が悪いのでおすすめしません。競プロerがお行儀を語るとか噴飯ものと言われるかもしれませんが。


クラスのメソッドと関数の速度差

主にデータ構造は内部に実データを持つと取り回しが楽なのでクラスを定義して使いたいわけですが、鈍足のPythonではどうしてもパフォーマンスへの影響が気になるもの。ですが、結論から言うと若干の差はあるものの気にするレベルではありません。例としてUnionFind木のクラス版と関数版を比較してみます。

クラス版

class UnionFind(object):
    __slots__ = ["nodes"]

    def __init__(self, n: int):
        self.nodes = [-1]*n

    def get_root(self, x: int) -> int:
        if self.nodes[x] < 0:
            return x
        else:
            self.nodes[x] = self.get_root(self.nodes[x])
            return self.nodes[x]

    def unite(self, x: int, y: int) -> None:
        root_x, root_y = self.get_root(x), self.get_root(y)
        if root_x != root_y:
            bigroot, smallroot = \
                (root_x, root_y) if self.nodes[root_x] < self.nodes[root_y] else (root_y, root_x)
            self.nodes[bigroot] += self.nodes[smallroot]
            self.nodes[smallroot] = bigroot

関数版

def get_root(nodes, x: int) -> int:
    if nodes[x] < 0:
        return x
    else:
        nodes[x] = get_root(nodes, nodes[x])
        return nodes[x]


def unite(nodes, x: int, y: int) -> None:
    root_x, root_y = get_root(nodes, x), get_root(nodes, y)
    if root_x != root_y:
        bigroot, smallroot = \
            (root_x, root_y) if nodes[root_x] < nodes[root_y] else (root_y, root_x)
        nodes[bigroot] += nodes[smallroot]
        nodes[smallroot] = bigroot

クラス版の__slots__属性は、これがあるとここに列挙されている名前以外のインスタンス変数を取れなくなり、動的に追加することができなくなる代わりにメモリ消費量やパフォーマンスが改善されるというものです。この__slots__の有無についても効果を見てみます。

AtCoder Typical Contest 001 B - Union Find

集合の連結と連結判定を求められる、ライブラリのverifyに使える問題です。頂点数<=100,000, クエリ数<=200,000

クラス(__slots__あり) 515ms
Submission #3955762 - AtCoder Typical Contest 001

クラス(__slots__なし) 507ms
Submission #3955758 - AtCoder Typical Contest 001

関数版 430ms
Submission #3955770 - AtCoder Typical Contest 001

互いに素な集合 Union Find| データ構造ライブラリ | Aizu Online Judge

求められる出力フォーマットが違うだけでそれ以外は上と全く同じ問題。頂点数<=10,000, クエリ数<=100,000

クラス(__slots__あり) 0.34s
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=3326955

クラス(__slots__なし) 0.40s
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=3326956

関数版 0.32s
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=3326957

__slots__の効果が微妙な感じだったのが残念ですが(これも最適化ネタとして紹介したかったのに……)、とりあえず現実的な制約下ではクラスのメソッドを20万回くらい叩いても0.5秒程度、関数版との差は0.1秒未満なので、生のリストを持って関数に投げるような面倒な取り回しをするよりもクラス化してしまったほうがいいと思います。


for _ in [0]*x?

In [61]: %%timeit
    ...: for i in range(10**8):
    ...:     1+1
    ...:
1.46 s ± 74.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [62]: %%timeit
    ...: for _ in [0]*(10**8):
    ...:     1+1
    ...:
1 s ± 7.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [63]: %%timeit
    ...: for _ in [None]*(10**8):
    ...:     1+1
    ...:
1 s ± 9.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 

ループカウンタを参照しないときにrangeを回すより[0]*xを回したほうがほんの少し速いです。[None]のほうがさらに微妙に最速だったと記憶していましたが違いは出ませんでした。ぶっちゃけ無意味で邪悪なので使わなくていいです。ただ参照しないループカウンタは_で参照しないことを明示するのは有用と思います。IDEも先頭が_の変数名は未使用でも怒らなかったりするんでpepかなんかに書いてある気がします(適当


割り算を避けても速くならない

多くの言語では割り算や剰余計算が他の計算より桁一つ以上遅かったりするので、定数倍が気になる場面では偶奇判定に1とのandを取ったり、2で割って切り捨てるのをビットシフトに置き換えたりするテクがありますが、Pythonの割り算や剰余計算は他の演算とほぼ等速なのでそのような工夫をしても速くなりません。

In [11]: %%timeit -n1 -r1
    ...: ans = 0
    ...: for i in range(10**7):
    ...:     ans += i % 2
    ...: print(ans)
    ...: 
5000000
483 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [12]: %%timeit -n1 -r1
    ...: ans = 0
    ...: for i in range(10**7):
    ...:     ans += i & 1
    ...: print(ans)
    ...: 
5000000
550 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [13]: %%timeit -n1 -r1
    ...: ans = 0
    ...: for i in range(10**7):
    ...:     ans += i // 2
    ...: print(ans)
    ...: 
24999995000000
576 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [14]: %%timeit -n1 -r1
    ...: ans = 0
    ...: for i in range(10**7):
    ...:     ans += i >> 1
    ...: print(ans)
    ...: 
24999995000000
593 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

ビット演算は一般的に速いイメージですが、Pythonだと何故か割り算を含む四則演算よりほんの少し遅い感じなので、わざわざ避けるほどではないですが速度を稼ぎたいときにビット演算に直すテクは有効ではありません。