素数を無限に生成するイテレータを定義したい

すむーずぷりんです。
イテレータとは Python で for 文でどんどん要素出せるやつです。例えば

for i in range(10):
    print(i)

のときの range(10) みたいなやつです(正確には多分 range(10) に対応するイテレータオブジェクトの iter(range(10)) が生成されてます)。これの似たようなやつを自作して、素数をどんどん取り出せるものを作ろう、というのがこの記事の目標です。

簡単なイテレータ

Pythonイテレータについては、ググるといろいろ解説が出てきます。例えば↓とか分かりやすいです。
qiita.com

まず試しに range(10) と同じ挙動をするイテレータを定義してみます。

class TenIter:
    def __init__(self):
        self.idx = -1
        self.max_idx = 10

    def __iter__(self):
        return self

    def __next__(self):
        self.idx += 1
        if self.idx >= self.max_idx:
            raise StopIteration()
        else:
            return self.idx

これを定義した上で

for i in IterTen():
    print(i)

とすると、0~9の整数が標準出力に表示されます。詳しく挙動を見てみましょう。
上のコードはおおよそ下と等価です。

ten_iter = TenIter()
try:
    while True:
        i = ten_iter.__next__()
        print(i)
except StopIteration:
    pass

まず TenIter のオブジェクトを生成したものが ten_iter に代入されます。ten_iter の __next__ メソッドが呼び出されたことにより、0 がリターンされ、i に代入されます。その後 print 文で 0 が標準出力されます。まだ例外が出ていないので、もう一度 ten_iter.__next__() が実行されます。ten_iter.idx がさらに1増えて1になるので、今度は i に 1 が代入されます。そして print 文で 1 が標準出力されます。その後もループが続き、標準出力に 2, 3, 4,... と表示されます。
ループが終わるのは ten_iter.idx が 10 になったときです。このとき ten_iter.idx >= ten_iter.max_idx (==10) が成立するため、raise StopIteration() されます。すると except 文のところに飛び、ループから出ることになります。

上記をまとめると

  1. for 文の中で__next__ が繰り返し呼ばれ、戻り値が標準出力される
  2. あるタイミングで raise StopIteration() されて終了する

となります。

ちょっといたずらして raise StopIteration() しないように TenIter() を書き換えてみます:

class TenIter2:
    def __init__(self):
        self.idx = -1
        self.max_idx = 10

    def __iter__(self):
        return self

    def __next__(self):
        self.idx += 1
        return self.idx # ここで return してしまう
#        if self.idx >= self.max_idx:
#            raise StopIteration() # もう例外は出ない...
#        else:
#            return self.idx

この状態で下を実行すると無限ループになります。

for i in TenIter2():
    print(i)

標準出力に 0, 1, 2, 3, ... と無限に表示されるようになります。こんな風に raise StopIteration() がないと際限なく処理が続いてしまいます。

さて、この記事の目標を思い出しましょう。無限に素数を生成するイテレータを作ることが目標でした。ここまでの知見から、必要なことは

  1. __next__ で「次の素数」をリターンするように実装する
  2. raise StopIteration() を「しない」

のふたつです。つまり

class PrimeIter:
    def __init__(self):
        pass # __next__ の実装に必要なもの

    def __iter__(self):
        return self

    def __next__(self):
        pass # 「次の」素数をリターンする処理

の pass の部分を適切に埋めれば完成です!

「次の素数」を見つける作戦

__next__ のメソッドをうまく定義すればよいということはわかりました。しかし「次の素数はなにか?」という問題は基本的に超難しい問題です。難しい話は全然分からないので読者に理解していただけるように、比較的簡単な「エラトステネスの篩(ふるい)」と「ベルトランの仮説」を利用してみます。

エラトステネスの篩

篩とはあのザルみたいなやつです。なんか上にもの入れてフリフリしたら大きい粒のものは残って細かい粒が下に落ちていくというアレです。この篩に掛ける操作のように、自然数の中から素数でないものをどんどん取り除いていくというのが「エラトステネスの篩」です。
かなり有名な論法なので、ネット上でたくさん解説記事が見つかります。例えば
mathtrain.jp
とか
note.com
とかが最初にヒットしました。他にも、下の著書でも丁寧に解説があったように記憶しています。
www.amazon.co.jp

例としてエラトステネスの篩を使って30以下の素数を全部求めてみましょう。まず 2 以上 30 以下の整数のリストを用意します:

{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}

この中で一番小さい数の 2 を残し、2以外の2の倍数をすべて除去します。

{2, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29}

2 の次にある 3 を残し、3以外の3の倍数をすべて除去します。

{2, 3, 5, 7, 11, 13, 17, 19, 23, 25, 29}

3 の次にある 5 を残し、5以外の5の倍数をすべて除去します。

{2, 3, 5, 7, 11, 13, 17, 19, 23, 29}

残ったものはすべて素数となりました。同様に、n 以下のすべての素数を知るには、{2, 3, ..., n} からスタートして、\sqrt{n} 以下の整数に対して操作が完了した時点で篩掛けが完了します。

ベルトランの仮説

任意の自然数 n に対して、n < p \leqq 2n を満たす素数  p が必ず存在します。「仮説」とついていますが、実際には証明された定理です。
ja.wikipedia.org

悪魔合体

「次の素数」を見つけるためのヒントが揃いました。素数 p が知られているなら、ベルトランの仮説より次の素数2p 以下です。さらにエラトステネスの篩を 2p 以下の整数に適用すれば、かならずすべての素数が得られます。ところでエラトステネスの篩のときに「hoge を残して、hoge 以外の hoge の倍数を除去する」と言ったときの hoge はすべて素数でした。これらを合わせると、次の作戦が可能なことが分かります。

  1. p 以下の素数をすべて記録しておく
  2. \{p+1, p+2, \dots, 2p-1\} から各 p 以下(かつ  \sqrt{2p} 以下)の素数の倍数を取り除く
  3. かならず要素が残り、最小のものが p の次の素数

いざ実装!

素数を記録するものを __init__ 内に用意しておきます。 __next__ ではエラトステネスの篩を実行しましょう。具体的には下のようになります。

class PrimeIter:
    def __init__(self):
        self.p_list = []

    def __iter__(self):
        return self

    def __next__(self):
        if self.p_list == []:
            self.p_list.append(2)
            return 2

        # 前回の素数を呼び出し
        p_prev = self.p_list[-1]
        # 篩(=sieve)を初期化
        sieve = range(p_prev + 1, 2*p_prev)
        # 篩にかける
        for p in self.p_list:
            sieve = [k for k in sieve if k % p != 0]
            if p*p > 2*p_prev:
                break
        # かならず何か残り、最小のものが次の素数
        p_next = sieve[0]
        self.p_list.append(p_next)
        return p_next

これで完成です!心置きなく素数を無限に生成しましょう。

for p in PrimeIter():
    print(p) # 無限に素数を標準出力

おまけ

どこかでストップしてほしい場合は適当な条件で raise StopIteration() しましょう。例えば以下のようにすればよいと思います。

class PrimeIter:
    def __init__(self, p_sup):
        self.p_list = []
        self.p_sup = p_sup

    def __iter__(self):
        return self

    def __next__(self):
        if self.p_list == []:
            if self.p_sup > 2:
                self.p_list.append(2)
                return 2
            else:
                raise StopIteration()

        p_prev = self.p_list[-1]
        sieve = range(p_prev + 1, 2*p_prev)
        for p in self.p_list:
            sieve = [k for k in sieve if k % p != 0]
            if p*p > 2*p_prev:
                break
        p_next = sieve[0]
        if p_next < self.p_sup:
            self.p_list.append(p_next)
            return p_next
        else:
            raise StopIteration()

以下は実行例です。

for p in PrimeIter(100):
    print(p) # 100 以下の素数を標準出力