コンテンツにスキップ

12月27日(金)正午12時入金分まで年内出荷いたします。それを過ぎると1月6日(月)以降の出荷となりますのでご注意ください。
また年始は数日の間当日出荷ができない可能性がございます。詳細につきましてはこちらの記事をご確認ください。

12月27日(金)正午12時入金分まで年内出荷いたします。それを過ぎると1月6日(月)以降の出荷となりますのでご注意ください。また年始は数日の間当日出荷ができない可能性がございます。

詳細につきましてはこちらの記事をご確認ください。

Pythonで動的にモジュールをインポートして実行するようにするけど、モジュール側の定義もチェックしたい話

Pythonで動的にモジュールをインポートして実行するようにするけど、モジュール側の定義もチェックしたい話

こんにちは。
ECシステム開発チームのいまづです。
みなさん AWS Lambda のPython環境アップデートしてますか?

僕らもしていますが、Lambda関数が多くて大変です。ついでにライブラリのアップデートしたりするのですが、CDKをTypeScriptで書いていることもあって、nodeのライブラリのアップデートも一緒にするので余計大変です。
そんな数あるLambda関数ですが、なんぼなんでも多すぎやろ、ということで減らせるところは減らす工夫をすることを考えています。

今回整理しようと思ったのは、いろんな集計用Lambda関数を作っているリポジトリです。
このリポジトリ(僕が作りました)では集計用のPythonモジュールごとにLambda関数を追加するようになっていて(僕がやりました)、集計の種類を増やすとLambda関数が増えていくという作りになっていて困ります(僕がやりました)。
これだと後々の(今も)メンテナンスが思いやられるので、パラメータで実行対象のモジュールを指定できるプラグインのような形にして、Lambda関数自体はひとつで済むようにしようと思います。
Python 3.12の標準ライブラリでやってみます。

Lambda関数ひとつで済ませる仕組みを考える

複数ある集計処理を実行するLambda関数をひとつで済ませようとするなら、なんらかの方法で実行対象の集計処理を切り替える必要があります。
単純に考えると、Lambda関数実行時に指定するPayloadに指定するパラメータで条件分岐を書く形でしょうか。

def handler(event, context):
    a_type = event['a_type']
    if a_type == 1:  # 売上集計
        ...

これだと、集計モジュールを追加するたびにこのハンドラの条件分岐を追加していく必要があって若干めんどくさいですね。

現状の実装では、集計の種類ごとにモジュールを作ってあって、それをそれぞれのハンドラ関数から使う形になっています。
集計処理関数のインターフェイスは微妙にそろっていませんがだいたい近いものになっているので、これを揃えて「実行対象のPythonモジュールを指定する」形にすると楽ができそうです。

イメージ

def handler(event, context):
    mod_name_a = event['mod_name']
    mod_a = load_module(mod_name_a)
    # 実行
    data = mod_a.execute(event)
    ...

これなら集計モジュールを新しく追加しても、ハンドラ関数自体は触る必要がないようにできそうです。

単純に「指定したモジュールをロードして実行する」だけの実験

こんな感じでファイルを用意します。

.
├── aggregators
│   ├── __init__.py
│   ├── mod_a.py
│   └── mod_b.py
└── main.py

aggregators の各ファイル
mod_a.py

def execute(params):
    print(f"mod_a executed with params: {params}")

mod_b.py

def execute(params):
    print(f"mod_b executed with params: {params}")

main.py は実行役

def handler(event, context):
    """Lambda関数用のハンドラ関数
    """
    mod_name = event[KEY_MOD_NAME]
    # `aggregators` パッケージからモジュールを動的に読み込む
    mod = importlib.import_module(f"aggregators.{mod_name}")
    mod.execute(event)


def main(mod_name):
    """ローカルでの動作確認用
    """
    event = {
        "mod_name": mod_name,
    }
    handler(event, None)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("mod_name", help="module name.")
    args = parser.parse_args()
    main(args.mod_name)

実行してみます。

$ rye run python main.py mod_a
mod_a executed with params: {'mod_name': 'mod_a'}

$ rye run python main.py mod_b
mod_b executed with params: {'mod_name': 'mod_b'}

$ rye run python main.py mod_c
Traceback (most recent call last):
  File "/home/imazu/sample/main.py", line 43, in <module>
    main(args.mod_name)
  File "/home/imazu/sample/main.py", line 34, in main
    handler(event, None)
  File "/home/imazu/sample/main.py", line 24, in handler
    mod = importlib.import_module(f"aggregators.{mod_name}")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/imazu/.rye/py/cpython@3.12.2/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1324, in _find_and_load_unlocked
ModuleNotFoundError: No module named 'aggregators.mod_c'

はい。存在するファイルを指定したらそれぞれが動いていますし、存在しないファイルなら当然エラーになりますね。

実際はもうちょっと複雑なことがしたくなる

これだけでもLambda関数減らす目的からするとまあいいかな、というところなのですが、入力パラメータのバリデーションエラーのハンドリングだとかは共通にしておきたいとかが出てきます。

実行部分にバリデーション追加するとか

def handler(event, context: dict):
    """Lambda関数用のハンドラ関数
    """
    mod_name = event["mod_name"]
    # `aggregators` パッケージからモジュールを動的に読み込む
    mod = importlib.import_module(f"aggregators.{mod_name}")

    # バリデーション
    try:
        mod.validate(event)
    except exceptions.ValidationError as e:
        logger.error(e, exc_info=True)
        raise e

    # 集計処理実行
    mod.execute(event)

バリデーション用の例外定義(exceptions.py

class ValidationError(Exception):
    pass

各モジュールにも関数を生やす

import exceptions


def validate(params: dict):
    if "year" not in params:
        raise exceptions.ValidationError("year is required.")


def execute(params: dict):
    print(f"mod_a executed with params: {params}")

他にもメインの集計処理実行後にどうにかしたいことがあったりしますね。
aggregators/mod_a.py に下記を追加します。

def post_execute(params: dict):
    # 例えばGoogle Driveにもアップロードするとか
    print(f"mod_a.post_execute() executed with params: {params}")

でも、mod_b では不要かもしれません。

指定するモジュールの書き方ドキュメントが必要

他のメンバーや未来の自分が集計モジュールを追加するときに参考にできるドキュメントを書いておく必要がありますね。
でも文章で書くよりも、モジュールのテンプレートになるようなPythonファイルを用意しておいた方がわかりやすいように思います。

aggretators/module_template.py としてこういうのを置いておくことにします。

"""追加モジュールテンプレート

このファイルに定義されている関数を、追加するモジュールでは定義する
必要があります。
"""
import exceptions


def validate(params: dict):
    """パラメータのバリデーションを行います。

    Lambda関数ハンドラの`event`が`params`に指定されるものとして、
    必要なチェックを行います。
    不正なパラメータがある場合は、`exceptions.ValidationError`を
    raiseしてください。
    """
    raise NotImplementedError


def execute(params: dict):
    """集計処理を実行します。

    `params` には、Lambda関数ハンドラの`event`が指定されます。
    """
    raise NotImplementedError


def post_execute(params: dict):
    """集計処理完了後に別途実行したい処理を記述します。

    例えばGoogle Driveにもアップロードするとか。
    """
    raise NotImplementedError

モジュールに必要な関数が定義されているかを確認するテストがほしい

集計モジュールに複数の関数が必要になってくると、数が増えた場合にはそれらが必要な関数を実装しているのか確認する手段が欲しくなります。
厳密でなくてもいいけど、少なくとも呼び出し時にエラーにならない程度には確認できると安心です。

inspect モジュールを使うと、ロードしたモジュールの関数定義を得られます。これで確認ができそうな気がします。

>>> import importlib
>>> import inspect
>>> mod = importlib.import_module('aggregators.mod_a')
>>> sig = inspect.signature(mod.execute)
>>> sig.parameters 
mappingproxy(OrderedDict({'params': <Parameter "params: dict">}))

テストを書いてみます。

import importlib
import inspect
import unittest


class TestAggregators(unittest.TestCase):
    """集計モジュールに必要な関数が定義されているかを確認するテスト
    """

    def test_modules(self):
        module_names = [
            'aggregators.mod_a',
            'aggregators.mod_b',
        ]
        func_names_and_arg_names = {
            'validate': ['params'],
            'execute': ['params'],
            'post_execute': ['params'],
        }

        for module_name in module_names:
            mod = importlib.import_module(module_name)
            for fname, argnames in func_names_and_arg_names.items():
                self.assertTrue(hasattr(mod, fname), f'{mod.__name__}.{fname}')  # 関数が存在する
                sig = inspect.signature(getattr(mod, fname))
                keys = list(sig.parameters.keys())  # 仮引数名を取得
                self.assertEqual(argnames, keys, f'{argnames}, {keys}')  # 仮引数名が一致する


if __name__ == '__main__':
    unittest.main()

実行してみると、

rye run python tests.py -v
test_modules (__main__.TestAggregators.test_modules) ... FAIL

======================================================================
FAIL: test_modules (__main__.TestAggregators.test_modules)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/imazu/sample/tests.py", line 24, in test_modules
    self.assertTrue(hasattr(mod, fname), f'{mod.__name__}.{fname}')  # 関数が存在する
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: False is not true : aggregators.mod_b.post_execute

----------------------------------------------------------------------
Ran 1 test in 0.001s

FAILED (failures=1)

エラーになりました。
これは post_execute() 関数を aggregatores/mod_b.py には追加しなかったためですが、何もしない関数を定義するというのもイマイチな気がします。

それに、この部分

        module_names = [
            'aggregators.mod_a',
            'aggregators.mod_b',
        ]
        func_names_and_arg_names = {
            'validate': ['params'],
            'execute': ['params'],
            'post_execute': ['params'],
        }

モジュールが増えたり、関数定義を変えたりしたときに、ここを変更するのはめんどくさそうです。

テンプレートに関数定義があるじゃないか

ドキュメント代わりに書いたテンプレートには各モジュールが実装すべき関数の定義があります。
テストにいちいち関数名や仮引数について書くのではなく、「テンプレートと同じである」確認で良いことにすれば、テンプレートの更新漏れも防げて良いのでは?と考えました。

先に書いてみたテストで仮引数名を取得しているところがありますが、

                sig = inspect.signature(getattr(mod, fname))
                keys = list(sig.parameters.keys())  # 仮引数名を取得 

ここの sig.parameters.values() には inspect.Parameter のインスタンスが入っています。
これは、仮引数名や型などの情報を持っています。 __eq__ メソッドの内容をみると、==で比較して良さそうです。

関数名のリストがほしいので dir(mod) としようと思いますが、このままだと __ で始まるような関数もインポートしているモジュールなども取ってきてしまうので、_で始める関数名は無視して inspect.isfunction なものだけにします。

>>> [x for x in dir(mod_a) if not x.startswith('_') and inspect.isfunction(getattr(mod_a, x))]
['execute', 'post_execute', 'validate']

テストのメソッドはこうなりました。

    def test_modules(self):
        # テンプレートモジュールの関数を取得
        mod_tmpl = importlib.import_module('aggregators._module_template')
        funcs_tmpl = {  # 関数名を使いたいので、辞書にしておく
            x: getattr(mod_tmpl, x)
            for x in dir(mod_tmpl)
            if inspect.isfunction(getattr(mod_tmpl, x)) and not x.startswith('_')  # `_`で始まる関数は対象外
        }
        module_names = [
            'aggregators.mod_a',
            'aggregators.mod_b',
        ]

        for module_name in module_names:
            mod_target = importlib.import_module(module_name)
            for fname, func_tmpl in funcs_tmpl.items():
                # テンプレート関数の引数定義
                params_tmpl = inspect.signature(func_tmpl).parameters.values()
                # 対象モジュールにテンプレート関数が存在するか
                self.assertTrue(hasattr(mod_target, fname), f'{mod_target.__name__}.{fname}')  # 関数が存在する
                # 比較対象の関数
                func_target = getattr(mod_target, fname)
                params_mod = inspect.signature(func_target).parameters.values()
                # 引数定義比較
                for pt, pm in zip(params_tmpl, params_mod):
                    self.assertEqual(pt, pm, f'{module_name}.{fname}, {pt.name}')

関数定義のリストはなくせました。
でも、これはやっぱりfailします。なぜなら、

post_execute() 関数を aggregatores/mod_b.py には追加しなかったため

を解消していないからです。

「テストで無視してもいい関数定義」をどう定義する?

パッと思いついたのはテンプレート側の関数に「これはなくてもいいよ」デコレータを書くというものでした。
他にもあるのかもしれませんが、一番手っ取り早そうに思いましたので、これで進めます。
テンプレートのファイルに書いてしまおうと思いますが、上記テストの「_で始まる関数は対象外」の条件に引っ掛けたいので、_始まりの関数名にします。

def _optional(func):
    '''オプション引数を宣言するデコレータ

    このデコレータを付与した関数は、テスト実行時にモジュール側に定義されていないくても
    エラーにならない。
    '''
    func._is_optional = True
    return func

対象のテンプレート関数にデコレータをつけておきます。

@_optional
def post_execute(params: dict):
    """集計処理完了後に別途実行したい処理を記述します。

    例えばGoogle Driveにもアップロードするとか。
    """
    raise NotImplementedError

テストの関数の存在チェックを変更します。

        module_names = [
            'aggregators.mod_a',
            'aggregators.mod_b',
        ]

        for module_name in module_names:
            for fname, func_tmpl in funcs_tmpl.items():
                # テンプレート関数の引数定義
                params_tmpl = inspect.signature(func_tmpl).parameters.values()

                # 対象モジュールにテンプレート関数が存在するか
                if not hasattr(mod_target, fname):
                    # optionalが指定されていればスルーしてOK
                    self.assertTrue(hasattr(func_tmpl, '_is_optional'))
                    continue  # 関数がないので以下のチェックはスキップ

                # 比較対象の関数
                func_target = getattr(mod_target, fname)

これでテストが通るようになりました。

対象モジュールも動的に得るようにしたい

ここまでくると対象モジュールをハードコードしているのもどうにかしたいです。
[pkgutilwalk_packages] (https://docs.python.org/ja/3/library/pkgutil.html#pkgutil.walk_packages) を使って、aggregators のモジュールを得られるように、、、と思いましたが、このままではテンプレートモジュールまで読み込まれてしまいます。
関数名と同じく名前で解決してしまうことにして、ファイル名を _module_template.py に変更します。

最終的にはこのようになりました。

import importlib
import inspect
import pkgutil
import unittest


class TestAggregators(unittest.TestCase):
    """集計モジュールに必要な関数が定義されているかを確認するテスト
    """

    def test_modules(self):
        # テンプレートモジュールの関数を取得
        mod_tmpl = importlib.import_module('aggregators._module_template')
        funcs_tmpl = {  # 関数名を使いたいので、辞書にしておく
            x: getattr(mod_tmpl, x)
            for x in dir(mod_tmpl)
            if inspect.isfunction(getattr(mod_tmpl, x)) and not x.startswith('_')
        }

        # 対象モジュールを取得
        for modinfo in pkgutil.walk_packages(['aggregators'], 'aggregators.'):
            # アンダースコアで始まるモジュールはスキップ
            if modinfo.name.split('.')[-1].startswith('_'):
                continue

            mod_target = importlib.import_module(modinfo.name)
            for fname, func_tmpl in funcs_tmpl.items():
                # テンプレート関数の引数定義
                params_tmpl = inspect.signature(func_tmpl).parameters.values()
                # 対象モジュールにテンプレート関数が存在するか
                if not hasattr(mod_target, fname):
                    # optionalが指定されていなければエラー
                    self.assertTrue(hasattr(func_tmpl, '_is_optional'))
                    continue  # 関数がないので以下のチェックはスキップ

                # 比較対象の関数
                func_target = getattr(mod_target, fname)
                params_mod = inspect.signature(func_target).parameters.values()
                # 引数定義比較
                for pt, pm in zip(params_tmpl, params_mod):
                    self.assertEqual(pt, pm, f'{modinfo.name}.{fname}, {pt.name}')


if __name__ == '__main__':
    unittest.main()

で、どうなったかというと

当初の目論見通り、Lambda関数をひとつにまとめることができそうです。
追加するときにPython以外の部分を考えなくて良くなったところは楽ができそうですね。
新しいモジュールを aggregators/ に追加するだけで済みます。
ただし、実行時に無効なモジュールを指定してしまったり、ということはあり得るでしょうから、利用する側の実装時には注意とテストが必要ですね。
あと、うっかりモジュール名を変えてしまうとかですね。

この記事では例のために単純な関数定義しか使いませんでしたが、実際に使っているものはもうすこし複雑な仮引数のリストを持っていたりします。ですが、おおよそこのやりかたの延長のものを使っています。

こういうやり方もあるよ、と思ったそこのあなた

スイッチサイエンスでは開発者を募集しています。
ぜひカジュアル面談で教えてください。

 

 


 

システムエンジニア募集中です!

興味のある方はぜひカジュアル面談へ!お待ちしています。

前の記事 Hakko is Back at SparkFun!