Alexandre Petit

Untestable code to testable code with Subclass and Override

Reading time: 3 min.

In "Enabling safe refactoring with pytest and coverage", we saw how to secure a codebase and how to fill the holes in a test suite. But we cheated in one respect.

We used tricks to tame I/O, such as swapping sys.stdout for a StringIO and reseeding the global random module. That's fair, but it has limitations:

Here is a way to intercept infrastructure code without relying on global objects.
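In miniature, the idea looks like this. It is a toy sketch: Greeter, RecordingGreeter and greet are hypothetical names, not part of trivia.py. The production class routes all of its output through one method, and a test subclass overrides only that method.

```python
class Greeter:
    """Production code: the only I/O goes through the _log seam."""

    def greet(self, name):
        self._log("Hello, " + name)

    def _log(self, message):
        print(message)  # infrastructure code, isolated in one method


class RecordingGreeter(Greeter):
    """Test subclass: same behaviour, but _log records instead of printing."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def _log(self, message):
        self.messages.append(message)


greeter = RecordingGreeter()
greeter.greet("Chet")
assert greeter.messages == ["Hello, Chet"]  # nothing was printed
```

No global object is touched: the test reads the messages straight from the instance.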

We will proceed in two steps: first regroup the calls to infrastructure code, then substitute them with Subclass and Override.

Setup

Retrieve the code: trivia.py

Create an empty play() method in the Game class.

Move the body of the if __name__ == '__main__': block into Game.play.

from random import randrange  # play() relies on randrange


class Game:
    ...
    def play(self):
        not_a_winner = False
        self.add('Chet')
        self.add('Pat')
        self.add('Sue')

        while True:
            self.roll(randrange(5) + 1)

            if randrange(9) == 7:
                not_a_winner = self.wrong_answer()
            else:
                not_a_winner = self.was_correctly_answered()

            if not not_a_winner: break


if __name__ == '__main__':
    game = Game()
    game.play()

We can reuse the test written in the previous post.

Create test_trivia.py:

import random
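from contextlib import redirect_stdout
from io import StringIO

from trivia import Game

def test_trivia():
    for i in range(1000):
        seed = i * 100
        random.seed(seed)
        output = StringIO()

        # redirect_stdout restores sys.stdout when the block exits,
        # so the redirection cannot leak into other tests
        with redirect_stdout(output):
            game = Game()
            game.play()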

        approved_path = f"data/test_trivia-approved-{seed}.txt"
        # Generate the approved versions
        # with open(approved_path, "w") as f:
        #     f.write(output.getvalue())
        with open(approved_path, "r") as f:
            expected = f.read()
        assert output.getvalue() == expected

Run pytest. It should pass.
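This test relies on two global objects: sys.stdout and the module-level generator that random.seed resets. Seeding works even though Game.play calls a bare randrange, assuming trivia.py imports it with from random import randrange, because that function is bound to the same hidden generator. A quick sketch to convince yourself (rolls is a hypothetical helper that only mimics the die roll in play()):

```python
import random
from random import randrange  # the import style assumed for trivia.py


def rolls(seed, count=10):
    """Hypothetical helper: seed, then roll the way Game.play does."""
    random.seed(seed)  # reseeds the shared module-level generator
    return [randrange(5) + 1 for _ in range(count)]


# randrange was imported by name, yet it follows random.seed:
# both are methods of the same hidden random.Random instance.
assert rolls(100) == rolls(100)
assert all(1 <= r <= 5 for r in rolls(100))
```

Any other code drawing from the same generator between seeding and playing would shift the sequence, which is one reason to be wary of global objects in tests.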

Regroup calls to infrastructure code

Run pytest to verify we start in a working state.

Look for a call to print:

    def add(self, player_name):
        self.players.append(player_name)
        self.places[self.how_many_players] = 0
        self.purses[self.how_many_players] = 0
        self.in_penalty_box[self.how_many_players] = False

        print(player_name + " was added")  # <-- here!
        print("They are player number %s" % len(self.players))

        return True

Extract the argument of the first call to print into a variable named message:

    def add(self, player_name):
        ...

        message = player_name + " was added"
        print(message)
        ...

Run pytest to check we didn't break anything.

Extract print(message) into a method named _log:

    def add(self, player_name):
        ...

        message = player_name + " was added"
        self._log(message)
        print("They are player number %s" % len(self.players))

        return True

    def _log(self, message):
        print(message)

Apply the transformation to every call to print so that _log is the only place where print is used.

If you're using PyCharm, it will offer to apply this transformation to the 19 similar fragments. However, this automated refactoring mangles string arguments that span several physical lines. Run black trivia.py -l 150 first, which joins those calls onto single lines, and the transformation will work fine.

Run pytest to check we didn't break anything.
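Before moving on, it can be worth confirming that no print escaped the transformation. One option, not from the original workflow, is to scan the source with the standard ast module; functions_calling_print is a hypothetical helper written for this sketch:

```python
import ast


def functions_calling_print(source):
    """Return the names of the functions whose bodies call print().

    Calls made outside any function are reported as "<module>".
    """
    callers = set()

    def visit(node, owner):
        for child in ast.iter_child_nodes(node):
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Attribute calls found inside a def to that function's name.
                visit(child, child.name)
                continue
            if (
                isinstance(child, ast.Call)
                and isinstance(child.func, ast.Name)
                and child.func.id == "print"
            ):
                callers.add(owner)
            visit(child, owner)

    visit(ast.parse(source), "<module>")
    return callers


# Usage, from the directory containing trivia.py:
# with open("trivia.py") as f:
#     print(functions_calling_print(f.read()))  # should report {'_log'}
```

Once every call has been moved, the only name reported should be _log.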

Inline the message variable.
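After inlining, add() should read roughly like this. The class below is a trimmed-down stand-in, not the full trivia.Game: its constructor only creates the state add() touches, and how_many_players is assumed to be a property returning len(self.players).

```python
class Game:
    """Hypothetical stand-in for trivia.Game: only the state add() touches."""

    def __init__(self):
        self.players = []
        self.places = [0] * 6
        self.purses = [0] * 6
        self.in_penalty_box = [0] * 6

    @property
    def how_many_players(self):
        return len(self.players)

    def add(self, player_name):
        self.players.append(player_name)
        self.places[self.how_many_players] = 0
        self.purses[self.how_many_players] = 0
        self.in_penalty_box[self.how_many_players] = False

        # message inlined: each line goes straight through the seam
        self._log(player_name + " was added")
        self._log("They are player number %s" % len(self.players))

        return True

    def _log(self, message):
        # the only call to print left in the class
        print(message)
```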

Run pytest to check we didn't break anything.

OK!

Substitute calls to infrastructure code with Subclass and Override

Now that the calls to print are grouped in _log, substituting them is easy.

We want to capture the logs without relying on the standard output.

Create a class TestableGame that inherits from Game and overrides the _log method:

class TestableGame(Game):
    __test__ = False  # <-- to prevent pytest trying to collect it

    def __init__(self):
        super().__init__()
        self._messages = []

    def _log(self, message):
        self._messages.append(message + "\n")

    def value(self) -> str:
        return "".join(self._messages)
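Why does _log append message + "\n"? Because print ends each call with a newline by default, so value() rebuilds exactly the text that used to reach the standard output, and the approved files still match. A self-contained sketch checking that equivalence on a hypothetical stand-in Game (via_stdout and via_subclass are illustrative helpers):

```python
import contextlib
import io


class Game:
    """Hypothetical stand-in for trivia.Game with a single logging method."""

    def __init__(self):
        self.players = []

    def add(self, player_name):
        self.players.append(player_name)
        self._log(player_name + " was added")
        self._log("They are player number %s" % len(self.players))
        return True

    def _log(self, message):
        print(message)


class TestableGame(Game):
    __test__ = False  # keep pytest from collecting it

    def __init__(self):
        super().__init__()
        self._messages = []

    def _log(self, message):
        self._messages.append(message + "\n")  # mirrors print's default end

    def value(self) -> str:
        return "".join(self._messages)


def via_stdout(names):
    """Capture the real output through the global sys.stdout."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        game = Game()
        for name in names:
            game.add(name)
    return buffer.getvalue()


def via_subclass(names):
    """Capture the same output through the overridden _log."""
    game = TestableGame()
    for name in names:
        game.add(name)
    return game.value()


names = ["Chet", "Pat", "Sue"]
assert via_stdout(names) == via_subclass(names)
```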

Now we can test our code without using the standard output:

def test_trivia():
    for i in range(1000):
        seed = i * 100
        random.seed(seed)
        game = TestableGame()

        game.play()

        approved_path = f"data/test_trivia-approved-{seed}.txt"
        # Generate the approved versions
        # with open(approved_path, "w") as f:
        #     f.write(game.value())
        with open(approved_path, "r") as f:
            expected = f.read()
        assert game.value() == expected
