Friday, July 31, 2020

How to capture/redirect output in Python's print

In Python 3, print became a function, which is one of the more visible changes from Python 2. You can tell it where to print by using the 'file' parameter: it can print to the screen as usual or, for example, into a text file. Here's how you make it print into a text file:

with open('x.txt', 'w', encoding='utf-8') as f:
    print('hello', file=f)

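The 'file' parameter accepts a stream object, meaning any object with a 'write' method, and the print function basically just converts what you give it to strings and passes them to this stream object's 'write' method. In fact, in the above example you could have written to the file without using print as follows (the explicit flush is optional here, since print does not flush by default and the file is flushed when the 'with' block closes it):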

with open('x.txt', 'w', encoding='utf-8') as f:
    f.write('hello')
    f.write('\n')
    f.flush()
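
To check this equivalence without touching the disk, here is a small sketch that swaps the file for io.StringIO, an in-memory text stream from the standard library (my choice for illustration, not something the example above uses). It also shows that 'sep' and 'end' end up in the stream like any other text:

```python
import io

# Let print do the formatting.
printed = io.StringIO()
print('hello', 'world', sep=', ', end='!\n', file=printed)

# Do the same thing by hand with write().
written = io.StringIO()
written.write('hello')
written.write(', ')   # what print sends for sep
written.write('world')
written.write('!\n')  # what print sends for end

# Both streams should now hold exactly the same text.
same = printed.getvalue() == written.getvalue()
```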

So the question is, what is the stream used by print when it prints to the screen? The print function's signature is defined as follows:

print(value, ..., sep=' ', end='\n', file=sys.stdout, flush=False)

So the answer is sys.stdout, which is a variable that by default points to the same object as sys.__stdout__, the original standard output stream as it was at the start of the program. In fact you can make your own print function by using the sys.__stdout__ object as follows:

import sys

def my_print(line):
    sys.__stdout__.write(line)
    sys.__stdout__.write('\n')
    sys.__stdout__.flush()

my_print('hello')

Note that the output will appear in the command line terminal. We have now seen that the default value of print's 'file' parameter is whatever sys.stdout is currently pointing to. This means that we can hijack this variable with our own stream so that default prints get redirected to our code instead of the screen. Here's how to make your own stream:

class MyStream(object):

    def write(self, text):
        pass

    def flush(self):
        pass

Just replace 'pass' in the write method with whatever you want to happen to the printed text. Now normally you would do something like this:

print('hello', file=MyStream())

but if you're trying to capture the output of some library, like sklearn, then you don't have control over the library's print calls. Instead you can do this:

import sys

tmp = sys.stdout
sys.stdout = MyStream()
try:
    print('hello')
finally:
    # Restore the original stream even if the code above raises.
    sys.stdout = tmp
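
If all you need is to capture what gets printed, the standard library already wraps this save-and-restore pattern in contextlib.redirect_stdout. Here is a minimal sketch, assuming an in-memory io.StringIO buffer as the target and a hypothetical noisy_function standing in for library code you cannot change:

```python
import contextlib
import io


def noisy_function():
    # Hypothetical stand-in for library code whose prints we cannot edit.
    print('hello')
    return 42


buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    # Inside this block, sys.stdout is the buffer.
    result = noisy_function()

# Everything printed inside the block is now available as a string.
captured = buffer.getvalue()
```

The previous sys.stdout is put back when the 'with' block exits, including when an exception is raised inside it. Any object with a write method can be passed instead of the buffer, so a custom stream like MyStream should work here too.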

With the manual swap, you first hijack the sys.stdout variable with your own stream and then, after the prints, restore it to the original stream. You need to be careful with what happens inside your stream's write method, because if there is another print in there then it will also call your stream's write method, which results in infinite recursion. What you can do is temporarily restore the original stream whilst inside your stream, like this:

import sys

class MyStream(object):

    def write(self, text):
        tmp = sys.stdout
        sys.stdout = sys.__stdout__
        try:
            print('my', text)
        finally:
            # Put our stream back even if the print fails.
            sys.stdout = tmp

    def flush(self):
        pass


tmp = sys.stdout
sys.stdout = MyStream()
try:
    print('hello')
finally:
    # Restore the original stream even if the print raises.
    sys.stdout = tmp

Now every time you print something, you'll get 'my' added to its front. Unfortunately, the print function sends the '\n' at the end of the line in a separate call to the stream's write method, which means that you actually get two calls to write per print:

my hello
my


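To see these separate calls directly, here is a small sketch with a hypothetical RecordingStream class (a name made up for this illustration) that just stores every string passed to its write method. It also shows that each argument and each separator gets its own call:

```python
class RecordingStream:
    """Hypothetical helper that remembers every write() call."""

    def __init__(self):
        self.calls = []

    def write(self, text):
        self.calls.append(text)

    def flush(self):
        pass


stream = RecordingStream()
print('hello', file=stream)
print('a', 'b', file=stream)

# In CPython, stream.calls is now ['hello', '\n', 'a', ' ', 'b', '\n']
```
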
A simple solution is to add an 'if' statement that checks whether the text received is just a '\n' and ignores it if so:

import sys

class MyStream(object):

    def write(self, text):
        if text == '\n':
            return

        tmp = sys.stdout
        sys.stdout = sys.__stdout__
        try:
            print('my', text)
        finally:
            # Put our stream back even if the print fails.
            sys.stdout = tmp

    def flush(self):
        pass


tmp = sys.stdout
sys.stdout = MyStream()
try:
    print('hello')
finally:
    # Restore the original stream even if the print raises.
    sys.stdout = tmp
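
One limitation of the '\n' check is that it also drops lines that are genuinely empty (for example from print() with no arguments), and a print with several arguments still produces one 'my' line per write call. A possible alternative, sketched here with a hypothetical PrefixStream class that is not part of the code above, is to buffer the incoming text and only emit complete lines. It writes directly to a target stream (sys.__stdout__ by default) instead of swapping sys.stdout back and forth, which also avoids the recursion problem:

```python
import io
import sys


class PrefixStream:
    """Hypothetical stream that prefixes every complete line with 'my '."""

    def __init__(self, target=None):
        # Fall back to the original stdout if no target is given.
        self.target = sys.__stdout__ if target is None else target
        self.pending = ''

    def write(self, text):
        self.pending += text
        # Emit every complete line; keep any unfinished tail for later.
        while '\n' in self.pending:
            line, self.pending = self.pending.split('\n', 1)
            self.target.write('my ' + line + '\n')

    def flush(self):
        self.target.flush()


# Demo: send the prefixed lines to an in-memory buffer so they can be inspected.
out = io.StringIO()
tmp = sys.stdout
sys.stdout = PrefixStream(out)
try:
    print('hello', 'world')
    print()
finally:
    sys.stdout = tmp

result = out.getvalue()
```

In this sketch, text that never ends with a newline stays in 'pending' and is not written out; whether to emit it in flush is a design choice left open here.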