Alexandre Petit

From untestable code to testable code with dependency inversion

Reading time: 3 min.

Dependency Injection is a great tool to:

Here is how to introduce it into tightly coupled code.

We can proceed in 4 steps:

  1. Set up the project
  2. Group calls to infrastructure code
  3. Move to delegate
  4. Inject dependency

Setup

Apply the instructions in Setup.

Regroup calls to infrastructure code

Apply the instructions in Regroup calls to infrastructure code.

Move to delegate

Let's do some wishful thinking.

I would like to have an instance of ConsoleLogger and call logger.log(message).

class Game:
    ...

    def _log(self, message: str):
        # what we would like
        logger = ConsoleLogger()
        logger.log(message)
        # what we actually have
        print(message)

Well, let's create it.

class ConsoleLogger:
    def log(self, message: str):
        pass

It does nothing for the moment.

Run pytest to check we didn't break anything.

Now, move the infrastructure code print(message) into the delegate.

class Game:
    ...

    def _log(self, message: str):
        logger = ConsoleLogger()
        logger.log(message)

    ...

class ConsoleLogger:
    def log(self, message: str):
        print(message)

Run pytest to check we didn't break anything.
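The approval test already checks that the whole game prints the same thing. You might also want a focused check that the delegate still writes to standard output. One option is the standard library's contextlib.redirect_stdout. The sketch below repeats ConsoleLogger so that it runs on its own. The helper capture_console_log and the sample message are hypothetical.

```python
import io
from contextlib import redirect_stdout


class ConsoleLogger:
    def log(self, message: str):
        print(message)


def capture_console_log(message: str) -> str:
    """Hypothetical helper: return what ConsoleLogger.log writes to stdout."""
    buffer = io.StringIO()
    with redirect_stdout(buffer):  # print() now writes into buffer
        ConsoleLogger().log(message)
    return buffer.getvalue()


print(repr(capture_console_log("Pat is the current player")))  # 'Pat is the current player\n'
```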

Invert dependency

We just took the infrastructure code out of our Game class.

For the sake of our test, we would like to pass an alternative implementation of the logger: for example, an in-memory implementation that never calls infrastructure code such as print.

To do so, we start by inverting the dependency.

Extract the interface Logger from ConsoleLogger:

from abc import ABCMeta, abstractmethod


class Logger(metaclass=ABCMeta):
    @abstractmethod
    def log(self, message: str):
        raise NotImplementedError


class ConsoleLogger(Logger):
    def log(self, message: str):
        print(message)

Run pytest to check we didn't break anything.
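Declaring log() with @abstractmethod on an ABCMeta class means Python refuses to instantiate Logger itself. It also refuses any subclass that forgets to implement log(). The sketch below illustrates this with a hypothetical IncompleteLogger and a hypothetical can_instantiate helper. The exact wording of the TypeError differs between Python versions, so only the outcome is shown.

```python
from abc import ABCMeta, abstractmethod


class Logger(metaclass=ABCMeta):
    @abstractmethod
    def log(self, message: str):
        raise NotImplementedError


class ConsoleLogger(Logger):
    def log(self, message: str):
        print(message)


class IncompleteLogger(Logger):
    """Hypothetical subclass that forgets to implement log()."""


def can_instantiate(cls) -> bool:
    """Hypothetical helper: True if cls() succeeds without a TypeError."""
    try:
        cls()
    except TypeError:  # raised while abstract methods remain unimplemented
        return False
    return True


print(can_instantiate(Logger))            # False: log() is abstract
print(can_instantiate(IncompleteLogger))  # False: log() was never implemented
print(can_instantiate(ConsoleLogger))     # True
```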

Write a test implementation (it's a good candidate for TDD):

class InMemoryLogger(Logger):
    def __init__(self):
        self._messages = []

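    def log(self, message: str):
        # print() ends each message with a newline; mimic it so the
        # output matches the approved files recorded from stdout.
        self._messages.append(message + "\n")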

    def getvalue(self) -> str:
        return "".join(self._messages)
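InMemoryLogger appends "\n" to every message because print() does the same. Thanks to this, the text it collects should match what ConsoleLogger would have written. The sketch below checks this by running both loggers on the same messages. The messages are made-up samples.

```python
import io
from abc import ABCMeta, abstractmethod
from contextlib import redirect_stdout


class Logger(metaclass=ABCMeta):
    @abstractmethod
    def log(self, message: str):
        raise NotImplementedError


class ConsoleLogger(Logger):
    def log(self, message: str):
        print(message)


class InMemoryLogger(Logger):
    def __init__(self):
        self._messages = []

    def log(self, message: str):
        self._messages.append(message + "\n")

    def getvalue(self) -> str:
        return "".join(self._messages)


messages = ["Chet was added", "They are player number 1"]  # sample data

in_memory = InMemoryLogger()
for message in messages:
    in_memory.log(message)

console_output = io.StringIO()
with redirect_stdout(console_output):  # capture what ConsoleLogger prints
    console = ConsoleLogger()
    for message in messages:
        console.log(message)

# Same text from both implementations, so approved files recorded from stdout still match.
print(in_memory.getvalue() == console_output.getvalue())  # True
```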

Now the question is: how do we use this implementation in our test?

Part of the answer is that instantiating the logger is not the game's responsibility, and even less the responsibility of one of its methods. This can be fixed in a few steps.

Introduce Field in the constructor:

class Game:
    def __init__(self):
        self._logger = ConsoleLogger()
        ...

    ...

    def _log(self, message):
        self._logger.log(message)

Introduce Parameter:

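class Game:
    def __init__(self, logger: "Logger | None" = None):
        # Default to the console; a None default avoids building ConsoleLogger()
        # once at definition time (a NameError if it is defined after Game).
        self._logger = logger if logger is not None else ConsoleLogger()
        ...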
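The point of the parameter is that the caller now decides which logger the game uses. The sketch below shows both wirings on a toy Game. The add method and the player names are hypothetical and stand in for the kata's real game logic. The sketch uses a None default, so a fresh ConsoleLogger is created per game rather than once when the function is defined.

```python
from abc import ABCMeta, abstractmethod
from typing import Optional


class Logger(metaclass=ABCMeta):
    @abstractmethod
    def log(self, message: str):
        raise NotImplementedError


class ConsoleLogger(Logger):
    def log(self, message: str):
        print(message)


class InMemoryLogger(Logger):
    def __init__(self):
        self._messages = []

    def log(self, message: str):
        self._messages.append(message + "\n")

    def getvalue(self) -> str:
        return "".join(self._messages)


class Game:
    """Toy stand-in for the kata's Game: only the logging wiring matters here."""

    def __init__(self, logger: Optional[Logger] = None):
        # Fall back to the console when no logger is injected.
        self._logger = logger if logger is not None else ConsoleLogger()

    def add(self, player_name: str):
        # Hypothetical method: just something that logs.
        self._log(f"{player_name} was added")

    def _log(self, message: str):
        self._logger.log(message)


Game().add("Chet")  # production wiring: prints "Chet was added"

logger = InMemoryLogger()
test_game = Game(logger)  # test wiring: nothing reaches stdout
test_game.add("Pat")
print(repr(logger.getvalue()))  # 'Pat was added\n'
```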

Go back to test_trivia.py.

Instantiate an InMemoryLogger and inject it into the game.

In the assert, use logger.getvalue() instead of output.getvalue() from the StringIO.

import random


def test_trivia():
    for i in range(1000):
        seed = i * 100
        random.seed(seed)
        # output = StringIO()                   -- delete this code
        # sys.stdout = output                   -- delete this code
        logger = InMemoryLogger()
        game = Game(logger)

        game.play()

        approved_path = f"data/test_trivia-approved-{seed}.txt"
        # Generate the approved versions
        # with open(approved_path, "w") as f:
        #     f.write(logger.getvalue())
        with open(approved_path, "r") as f:
            expected = f.read()

        # assert output.getvalue() == expected  -- delete this code
        assert logger.getvalue() == expected
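This approval test relies on two things. Seeding random makes each game replay identically, and logger.getvalue() is the text that gets recorded and compared. Below is a self-contained miniature of that loop. It uses a hypothetical ToyGame in place of Game and a temporary directory in place of data/.

```python
import random
import tempfile
from pathlib import Path


class InMemoryLogger:
    # Same behaviour as the article's InMemoryLogger; the Logger base
    # class is left out to keep this sketch short.
    def __init__(self):
        self._messages = []

    def log(self, message: str):
        self._messages.append(message + "\n")

    def getvalue(self) -> str:
        return "".join(self._messages)


class ToyGame:
    """Hypothetical stand-in for Game: rolls a die three times and logs each roll."""

    def __init__(self, logger):
        self._logger = logger

    def play(self):
        for turn in range(3):
            self._logger.log(f"Turn {turn}: rolled {random.randint(1, 6)}")


def run(seed: int) -> str:
    random.seed(seed)  # same seed, same sequence of rolls
    logger = InMemoryLogger()
    ToyGame(logger).play()
    return logger.getvalue()


with tempfile.TemporaryDirectory() as data_dir:
    for seed in (0, 100, 200):
        approved_path = Path(data_dir) / f"test_trivia-approved-{seed}.txt"
        # First pass: record the approved output (the commented-out step in the real test).
        approved_path.write_text(run(seed))
        # Later runs: the same seed must reproduce the approved output exactly.
        assert run(seed) == approved_path.read_text()

print("all approved outputs match")
```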

Now we can choose which logger implementation the game uses: ConsoleLogger in production, InMemoryLogger in tests.

Convenient, isn't it?
