1
/
5

#20 実践Pytest入門(PyCon JP 2024)

あなたのアプリケーションをレガシーコードにしないための実践Pytest入門

概要

レガシーコードは、テストがなく修正や拡張が難しいコードを指します。このようなコードは技術的負債となり、長期的には開発速度の低下や保守コストの増大につながります。

この問題を解決するために、Pytestを活用した単体テストの実践が効果的です。Pytestは、Pythonのテストフレームワークで、テストコードの作成と実行を容易にします。

実際にどのようにテストを書くことができるか見ていきましょう。

シンプルな関数のテスト

身長と体重からBMIを計算する例

def calc_bmi(*, height: float, weight: float) -> float:
if (height <= 0) or (weight <= 0):
raise ValueError("Height and weight must br greater then 0")
return weight / (height ** 2)

多数の分岐条件を持つケースのテスト

所得額から給与所得控除額を計算する例

import pytest

def calc_exemption_amount(*, income: int) -> int:
if income < 0:
raise ValueError("Income must be positive.")
if income <= 1_625_000:
return 550_000
if 1_625_000 < income <= 1_800_000:
return int(income * 0.4) - 100_000
if 1_800_000 < income <= 3_600_000:
return int(income * 0.3) - 800_000
if 3_600_000 < income <= 6_600_000:
return int(income * 0.2) - 440_000
if 6_600_000 < income <= 8_500_000:
return int(income * 0.1) - 1_100_000
return 1_950_000

class TestCalcExemptionAmount:
@pytest.mark.parametrize(
("income", "expected"),
[
(1_625_000, 550_000),
(1_625_003, 550_001),
(1_800_000, 620_000),
(1_800_001, 1_340_000),
(3_600_000, 1_880_000),
(3_600_001, 1_160_000),
(6_600_000, 1_760_000),
(6_600_001, 1_760_001),
(8_500_000, 1_950_000),
(8_500_001, 1_950_001),
]
)
def test_income_and_exemption(self, income, expected):
assert calc_exemption_amount(income) == expected

OS環境変数に依存するケースのテスト

OS環境変数に設定されているAPIのURLを取得する例

"""tests/env/conftest.py"""
import pytest

@pytest.fixture
def mock_env_api_url(monkeypatch):
monkeypatch.setenv("API_URL", "http://localhost:8080")
"""tests/env/test_env.py"""
import pytest

os.environ["API_URL"] = "https://production.example.com"

def get_api_url() -> str | None:
return os.getenv("API_URL")

class TestGetAPIURL:
def test_get_api_url(self, mock_env_api_url):
assert get_api_url() == "http://localhost:8080"

システム日時に依存するケースのテスト

システム時刻が営業時間中かどうかを取得する例

import pytest
from freezegun import freeze_time
from datetime import datetime, time

def is_in_business() -> bool:
now = datetime.now()
if now.weekday() in (5, 6):
return False
if time(9, 0, 0) <= now.time() <= time(17, 0, 0):
return True
return False

class TestIsInBusiness:
@pytest.mark.parametrize(
("now", "expected"),
[
("2024-09-27 08:59:59", False),
("2024-09-27 09:00:00", True),
("2024-09-27 17:00:00", True),
("2024-09-27 17:00:01", False),
("2024-09-28 12:00:00", False),
("2024-09-29 12:00:00", False),
]
)
def test_is_in_business(self, now, expected):
with freeze_time(now):
assert is_in_business() == expected

ファイル入出力のテスト

ファイルのテキストを読み込み、文字列に含まれる猫を犬に置換してファイル出力する例

import pytest
from pathlib import Path
import re

def cat_to_dog(*, input_path: Path, output_path: Path) -> None:
input_text = input_path.read_text()
output_text = re.sub("猫", "犬", input_text)
output_path.write_text(output_text)

class TestCatToDog:
def test_normal(self, tmp_path):
intput_path = tmp_path / "input.txt"
output_path = tmp_path / "output.txt"
input_path.write_text("吾輩は猫である。名前はまだない。")
cat_to_dog(input_path=input_path, output_path=output_path)
assert output_path.read_text() == "吾輩は犬である。名前はまだない。"

外部APIに依存したケースのテスト

郵便番号から住所を取得する外部APIを使用する例

"""tests/api/conftest.py"""
import re
import pytest
import requests
from dataclasses import dataclass

ResultsType = list[dict[str, str]] | None

@dataclass
class MockResponse:
message: str | None = None
results: ResultsType = None

def raise_for_status(self) -> None:
return None

def json(self) -> dict[str, str | ResultsType]:
return {"message": self.message, "results": self.results}

@pytest.fixture
def mock_response(monkeypatch) -> None:
def mock_get(*args, **kwargs) -> MockResponse:
zipcode = kwargs["params"]["zipcode"]
if zipcode == "0000000":
return MockResponse()
elif re.match("^[0-9]{7}$", zipcode):
return MockResponse(results=[{"address1": "都道府県", "address2": "市区町村", "address3": "番地"}])
else:
return MockResponse(message="郵便番号の桁数や値が不正です")
monkeypatch.setattr(requests, "get", mock_get)
"""tests/api/test_api.py"""
import pytest
import requests

ENDPOINT = "http://zipcloud.ibsnet.co.jp/api/search"

def get_address(*, zipcode: str) -> str | None:
response = requests.get(ENDPOINT, params={"zipcode": zipcode}, timeout=5)
response.raise_for_status()
data = response.json()

if (message := data["message"]) is not None:
raise ValueError(message)
if (results := data["results"]) is None:
return None
return f"{results[0]['address1']} {results[0]['address2']} {results[0]['address3']}"

class TestGetAddress:
@pytest.mark.parametrize(
("zipcode", "expected"),
[
("0000000", None),
("1111111", "都道府県 市区町村 番地")
]
)
def test_get_address(self, mock_response, zipcode, expected):
assert get_address(zipcode=zipcode) == expected

@pytest.mark.parametrize("zipcode", ["1", "12345678", "dummy"])
def test_invalid_zipcode(self, mock_response, zipcode):
with pytest.raises(ValueError) as e:
get_address(zipcode=zipcode)
assert str(e.value) == "郵便番号の桁数や値が不正です"

DB接続を伴うケースのテスト

DBのユーザを追加取得するテスト

"""src/db.py"""
import os
from datetime import date

from sqlalchemy import String, create_engine, insert, select
from sqlalchemy.engine import URL
from sqlalchemy.orm import declarative_base, Mapped, mapped_column, scoped_session, sessionmaker
from sqlalchemy.orm.scoping import scoped_session as scoping_scoped_session

Base = declarative_base()

class User(Base):
__tablename__ = "user"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(12), nullable=False)
birthday: Mapped[date] = mapped_column(nullable=False)
def __eq__(self, other) -> bool:
return (self.id == other.id) and (self.name == other.name) and (self.birthday == other.birthday)

DATABASE_CONFIG = {
"drivername": "mysql+pymysql",
"username": os.environ["MYSQL_USER"],
"password": os.environ["MYSQL_PASSWORD"],
"host": os.environ["MYSQL_HOST"],
"port": os.environ["MYSQL_POST"],
"database": os.environ["MYSQL_DATABASE"],
"query": {"charset": "utf8"}}
}
engine = create_engine(URL.create(**DATABASE_CONFIG), echo=False)
Base.metadata.create_all(engine)
Session = scoped_session(sessionmaker(engine))

def get_user(db_session: scoping_scoped_session, user_id: int) -> User | None:
stmt = select(User).where(User.id == user_id)
return db_session.scalar(stmt)

def add_user(db_session: scoping_scoped_session, user: User) -> int:
with db_session() as session:
session.add(user)
session.commit()
return user.id
"""tests/api/conftest.py"""
import os
from operator import itemgetter
import pytest
from pytest_mysql import factories
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
from db import Base

host, port, user, passwd = itemgetter("MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD")(os.environ)
mysql_noproc = factories.mysql_noproc(host=host, port=port, user=user)
mysql_fixture = factories.mysql("mysql_noproc", password=password, dbname="test")

@pytest.fixture
def test_session(mysql_fixture):
url = URL.create(drivername="mysql+pymysql", username=user, password=passwd, host=host, port=port, database="test", query={"charset": "utf8"})
engine = create_engine(url, echo=False, poolclass=NullPool)
Base.metadata.create_all(engine)
session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine))
try:
yield session
except Exception:
session.rollback()
else:
session.commit()
finally:
session.close()
Base.metadata.drop_all(engine)
"""tests/api/user.py"""
from datetime import date
from db import User, get_user, add_user

class TestUser:
def test_get_user(self, test_session):
user = User(name="sato", birthday=date(1999, 12, 31))
with test_session() as session:
session.add(user)
session.commit()
user_id = user.id
user_ = get_user(db_session=test_session, user_id=user_id)
assert user == user_

def test_no_user(self, test_session):
assert get_user(db_session=test_session, user_id=1) is None

def test_add_user(self, test_session):
user = User(name="sato", birthday=date(1999, 12, 31))
user_id = add_user(db_session=test_session, user=user)
assert user == test_session.scalar(select(User).where(User.id == user_id))

def test_duplicate_pk(self, test_session):
try:
for _ in range(2):
add_user(db_session=test_session, user=User(id=1, name="sato", birthday=date(1999, 12, 31)))
except sqlalchemy.exc.IntegrityError as e:
assert isinstance(e.orig, pymysql.err.IntegrityError)
assert e.orig.args[0] == 1062


まとめ

これらのパターンを理解し、適切にテストを実装することで、コードの品質を維持しやすくなります。
さらに、テスト駆動開発(TDD)を採用することで、テストとリファクタリングを同時に行うことができます。

TDDのサイクルは以下のようにすると良いです。

  • 失敗するテストを書く(レッド)
  • テストを通すコードを書く(グリーン)
  • コードをリファクタリングする

このアプローチにより、テストが容易なコードが自然と生まれ、結果的に開発スピードの向上につながります。Pytestを活用し、適切なテストを書くことで、レガシーコード化を防ぎ、長期的に保守性の高いアプリケーションを開発していきましょう。

このストーリーが気になったら、遊びに来てみませんか?
業績好調のため増員募集|経験豊富な代表直下で成長できるIT法人営業
ITEEK株式会社では一緒に働く仲間を募集しています

同じタグの記事

今週のランキング

上川 諒也さんにいいねを伝えよう
上川 諒也さんや会社があなたに興味を持つかも