Python の attrs で validation をかける

Python の attrs が便利すぎて、データ保持用のクラスを作る時なんかは専らコレばっかり使っています。

そんな attrs で入力されるデータに validation をかける方法です。

attrs で入力データをチェック

例えば Pose というクラスに対して、位置を表す position は長さ 3 の配列、向きを表す orientation は長さ 4 の配列(クォータニオン) であることをチェックする場合は以下のように書くことができます。

from attr import attrs, attrib


@attrs
class Pose:
    position = attrib()
    orientation = attrib()

    @position.validator
    def check_position(self, attribute, value):
        if len(value) != 3:
            raise ValueError("position must be a list with size=3")

    @orientation.validator
    def check_orientation(self, attribute, value):
        if len(value) != 4:
            raise ValueError("orientation must be a list with size=4")

正しい値を代入した場合

>>> Pose(position=[0, 1, 2], orientation=[0, 0, 0, 1])
Pose(position=[0, 1, 2], orientation=[0, 0, 0, 1])

不正な値を代入した場合

>>> Pose(position=[0, 1], orientation=[0, 0, 0, 0, 1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<attrs generated init d25cdc317f9cca07dcc37d3210f62862e6055553>", line 5, in __init__
  File "<stdin>", line 8, in check_position
ValueError: position must be a list with size=3

普通の class だと __init__ が肥大化したりしがちですが、これだとスッキリ書けますね!

関連サイト