Multi-process asynchronously with Python

What is this article?

I just wrote some code because I wanted to make Python faster.

What I wanted to do

My scientific calculations are slow. Both the CPU-bound parts and the I/O are slow. There are, of course, plenty of well-known speedup techniques out there.

In Python, multithreading doesn't make things faster because of the GIL, so inevitably you end up with multiprocessing and asynchronous processing. However, when you try to introduce asyncio for the asynchronous part, it looks like **coroutines used in asyncio cannot be pickled and are therefore incompatible with multiprocessing**... but **the two can in fact be combined.**

That said, even if they can be combined, I'm not sure it actually gets faster.

This time I won't talk about "process-based asynchronous processing"; I'll talk about "running coroutines in multiple processes".

Reinventing the wheel?

There are packages that do asynchronous multiprocessing, but pulling in another dependency just for this didn't seem worth it... The candidates I found were either old or based on multithreading, so I couldn't tell which was better. And since it looked easy to implement... I ended up writing it myself.

Necessary items: asyncio + Pool().map

For async, Python has asyncio; for multiprocessing, it has multiprocessing.Pool. The plan is to use the two together.

Failure example

However, the following does not work. It blows up because coroutines cannot be pickled.

from multiprocessing import Pool

async def calc(x):
    return x * 2

async def async_calc(x):
    result = await calc(x)
    return result

with Pool() as pool:
    # Fails: each worker returns a coroutine object,
    # and coroutine objects cannot be pickled.
    pool.map(calc, range(3))

If you can't pickle a coroutine, why not pass a function that returns a coroutine?

First, let's sort out what can and cannot be pickled.

| Object | Picklable? |
|:--|:--|
| Coroutine | No |
| Lambda expression | No |
| Function defined with def | Yes |

In other words, a function defined with def that returns a coroutine can be pickled.
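
A quick sanity check of the table (a minimal sketch; exact error messages vary by Python version):

import pickle

async def coro_func(x):
    return x * 2

pickle.dumps(coro_func)      # fine: a def function pickles by reference

cor = coro_func(1)           # ...but the coroutine object it returns
try:
    pickle.dumps(cor)        # does not
except TypeError as err:
    print(err)               # e.g. "cannot pickle 'coroutine' object"
finally:
    cor.close()              # avoid the "never awaited" warning

try:
    pickle.dumps(lambda x: x)
except pickle.PicklingError as err:
    print(err)               # lambdas cannot be pickled either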

What we need

As a premise, suppose you want to run the following in parallel.

- read_text: a function that reads a text file asynchronously
- main: the function that does the actual processing

I wrote it

Here, an async function and its arguments are passed to a function called _wrap_async, which spins the coroutine inside each worker process.

from typing import List, Iterable, Callable, Any
from asyncio import new_event_loop
from multiprocessing import Pool

async def read_text(fname: str) -> str:
    with open(fname) as fp:
        result = fp.read()
    return result

async def main(fname: str) -> str:
    result = await read_text(fname)
    return result

def _wrap_async(func: Callable, *args: Any) -> Any:
    # A fresh event loop per call, so a pooled worker can
    # run this more than once without hitting a closed loop.
    loop = new_event_loop()
    result = loop.run_until_complete(func(*args))
    loop.close()
    return result

def pmap_async(func: Callable, arg: Iterable,
               chunk: int = 2) -> List[Any]:
    # Only the async *function* and its argument are pickled;
    # the coroutine is created and awaited inside each worker.
    with Pool(chunk) as pool:
        result = pool.starmap_async(_wrap_async,
                                    [(func, a) for a in arg]).get()
    return result

if __name__ == '__main__':  # needed on platforms that spawn workers
    result = pmap_async(main, ['test.py'])
    print(result)

And it works. In the end, though, what I wrote amounts to a knock-off of the map function.

Chaining it like Node.js, because that seemed like a good idea

I hate writing code like map(func1, map(func2, map(func3, iterator))). Even if I break it up, I never know what to name the intermediate map objects. So I tried to make it chainable, like a Node.js Array.
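
For instance, a three-step toy pipeline (made up for illustration) reads inside out with plain map:

# The innermost map runs first; naming the intermediate steps doesn't help much.
result = list(map(str, map(lambda x: x + 1, map(lambda x: x * 2, [1, 2, 3]))))
print(result)  # ['3', '5', '7']

With the ChainIter class below, the same flow reads left to right.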

I did write it, but the code got long, so it's at the end of this article. Concretely, it can be used as follows. I use generators as much as possible, so the overhead should be small...

>>> async def test(x): return x * 2
>>> ChainIter([1, 2]).async_map(test, 2)[0]
2

Now you can use coroutines in a multi-process chain. ~~Though I'm told it violates the coding standards and I'll be ostracized for it.~~
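
Other operations chain the same way. A single-process sketch (lambdas are fine as long as only one process is used):

total = (ChainIter(range(10))
         .filter(lambda x: x % 2 == 0)
         .map(lambda x: x * x)
         .reduce(lambda a, b: a + b))
print(total)  # 0 + 4 + 16 + 36 + 64 = 120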

The long code

It grew long, with plenty of bells and whistles tacked on.

from asyncio import new_event_loop, ensure_future, Future
from typing import (Any, Callable, Coroutine, Iterable,
                    Union, cast)
from itertools import starmap
from multiprocessing import Pool
from doctest import testmod
from functools import reduce, wraps
from logging import INFO, getLogger
import time


logger = getLogger('ChainIter')
logger.setLevel(INFO)


def future(func: Callable) -> Callable:
    @wraps(func)
    def wrap(*args: Any, **kwargs: Any) -> Future:
        return ensure_future(func(*args, **kwargs))
    return wrap
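
# Calling a function wrapped with future() schedules its coroutine via
# ensure_future and returns the resulting Future; an event loop must be
# available at call time.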


def run_coroutine(cor: Coroutine) -> Any:
    """
    Just run coroutine.
    """
    loop = new_event_loop()
    result = loop.run_until_complete(cor)
    loop.close()
    return result


def run_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """
    Assemble coroutine and run.
    """
    loop = new_event_loop()
    result = loop.run_until_complete(func(*args, **kwargs))
    loop.close()
    return result
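
# Note: run_async receives the coroutine *function* and its arguments
# separately, so only picklable objects cross the process boundary; the
# coroutine itself is created and awaited inside the worker process.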


class ChainIter:
    """
    Iterator which can be used in a method chain, like Array of node.js.
    Multiprocessing and asyncio both work with it.
    """
    def __init__(self, data: Union[list, Iterable],
                 indexable: bool = False, max_num: int = 0):
        """
        Parameters
        ----------
        data: Iterable
            It need not to be indexable.
        indexable: bool
            If data is indexable, indexable should be True.
        """
        self.data = data
        self.indexable = indexable
        self.num = 0  # Iterator needs number.
        self.max = len(data) if hasattr(data, '__len__') else max_num
        self.bar = True
        self.bar_len = 30

    def map(self, func: Callable, core: int = 1) -> 'ChainIter':
        """
        Chainable map.

        Parameters
        ----------
        func: Callable
            Function to run.
        core: int
            Number of processes to use.
            If it is larger than 1, multiprocessing based on
            multiprocessing.Pool is used, and in that case
            func cannot be a lambda or a coroutine function.
        Returns
        ---------
        ChainIter with result

        >>> ChainIter([5, 6]).map(lambda x: x * 2).get()
        [10, 12]
        """
        logger.info(' '.join(('Running', str(func.__name__))))
        if core == 1:
            return ChainIter(map(func, self.data), False, self.max)
        with Pool(core) as pool:
            result = pool.map_async(func, self.data).get()
        return ChainIter(result, True, self.max)

    def starmap(self, func: Callable, core: int = 1) -> 'ChainIter':
        """
        Chainable starmap.
        In this case, ChainIter.data must be an iterable of iterables.

        Parameters
        ----------
        func: Callable
            Function to run.
        core: int
            Number of processes to use.
            If it is larger than 1, multiprocessing based on
            multiprocessing.Pool is used, and in that case
            func cannot be a lambda or a coroutine function.
        Returns
        ---------
        ChainIter with result
        >>> def multest2(x, y): return x * y
        >>> ChainIter([5, 6]).zip([2, 3]).starmap(multest2).get()
        [10, 18]
        """
        logger.info(' '.join(('Running', str(func.__name__))))
        if core == 1:
            return ChainIter(starmap(func, self.data), False, self.max)
        with Pool(core) as pool:
            result = pool.starmap_async(func, self.data).get()
        return ChainIter(result, True, self.max)

    def filter(self, func: Callable) -> 'ChainIter':
        """
        Simple filter function.

        Parameters
        ----------
        func: Callable
        """
        logger.info(' '.join(('Running', str(func.__name__))))
        return ChainIter(filter(func, self.data), False, 0)

    def async_map(self, func: Callable, chunk: int = 1) -> 'ChainIter':
        """
        Chainable map for coroutine functions (async def functions).

        Parameters
        ----------
        func: Callable
            Function to run.
        chunk: int
            Number of processes to use.
            If it is larger than 1, multiprocessing based on
            multiprocessing.Pool is used, and in that case
            func cannot be a lambda.
        Returns
        ---------
        ChainIter with result
        """
        logger.info(' '.join(('Running', str(func.__name__))))
        if chunk == 1:
            return ChainIter(starmap(run_async,
                                     ((func, a) for a in self.data)),
                             False, self.max)
        with Pool(chunk) as pool:
            result = pool.starmap_async(run_async,
                                        ((func, a) for a in self.data)).get()
        return ChainIter(result, True, self.max)

    def has_index(self) -> bool:
        return True if self.indexable else hasattr(self.data, '__getitem__')

    def __getitem__(self, num: int) -> Any:
        if self.has_index():
            return cast(list, self.data)[num]
        self.data = tuple(self.data)
        return self.data[num]

    def reduce(self, func: Callable) -> Any:
        """
        Simple reduce function.

        Parameters
        ----------
        func: Callable

        Returns
        ----------
        Result of reduce.
        """
        logger.info(' '.join(('Running', str(func.__name__))))
        return reduce(func, self.data)

    def get(self, kind: type = list) -> Any:
        """
        Get data as list.

        Parameters
        ----------
        kind: type
            If you want to convert the data to a type other than list,
            you can set it: for example, tuple, deque, and so on.
        """
        return kind(self.data)

    def zip(self, *args: Iterable) -> 'ChainIter':
        """
        Simple chainable zip function.
        Parameters
        ----------
        *args: Iterators to zip.

        Returns
        ----------
        ChainIter of zipped data.
        """
        return ChainIter(zip(self.data, *args), False, 0)

    def __iter__(self) -> 'ChainIter':
        self.calc()
        self.max = len(cast(list, self.data))
        self.num = 0
        # Timing state must survive between __next__ calls.
        self.start_time = self.prev_time = time.time()
        return self

    def __next__(self) -> Any:
        if self.num == self.max:
            if self.bar:
                print('\nComplete in {sec} sec!'.format(
                    sec=round(time.time() - self.start_time, 3)))
            raise StopIteration
        if self.bar:
            bar_str = '\r{percent}%[{bar}{arrow}{space}]{div}'
            cycle_token = ('-', '\\', '|', '/')
            cycle_str = '\r[{cycle}]'
            stat_str = ' | {epoch_time:.2g}sec/epoch | Speed: {speed:.2g}/sec'
            progress = bar_str + stat_str
            cycle = cycle_str + stat_str
            current_time = time.time()
            epoch_time = current_time - self.prev_time
            self.prev_time = current_time
            # Guard against division by zero on very fast iterations.
            speed = 1 / epoch_time if epoch_time else float('inf')
            if self.max != 0:
                bar_num = int((self.num + 1) / self.max * self.bar_len)
                print(progress.format(
                    percent=int(100 * (self.num + 1) / self.max),
                    bar='=' * bar_num,
                    arrow='>',
                    space=' ' * (self.bar_len - bar_num),
                    div=' ' + str(self.num + 1) + '/' + str(self.max),
                    epoch_time=round(epoch_time, 3),
                    speed=round(speed, 3)
                    ), end='')
            else:
                print(cycle.format(
                    cycle=cycle_token[self.num % 4],
                    div=' ' + str(self.num + 1) + '/' + str(self.max),
                    epoch_time=round(epoch_time, 3),
                    speed=round(speed, 3)
                    ), end='')
        self.num += 1
        return self.__getitem__(self.num - 1)

    def __reversed__(self) -> Iterable:
        if hasattr(self.data, '__reversed__'):
            return cast(list, self.data).__reversed__()
        raise IndexError('Not reversible')

    def __setitem__(self, key: Any, item: Any) -> None:
        if hasattr(self.data, '__setitem__'):
            cast(list, self.data)[key] = item
            return
        raise IndexError('Item cannot be set.')

    def arg(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Use the ChainIter object as an argument.
        It is the same as func(tuple(ChainIter), *args, **kwargs)

        Parameters
        ----------
        func: Callable

        Returns
        ----------
        Result of func(tuple(ChainIter), *args, **kwargs)
        >>> ChainIter([5, 6]).arg(sum)
        11
        """
        return func(tuple(self.data), *args, **kwargs)

    def stararg(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Use the ChainIter object as unpacked arguments.
        It is the same as func(*tuple(ChainIter), *args, **kwargs)

        Parameters
        ----------
        func: Callable

        Returns
        ----------
        Result of func(*tuple(ChainIter), *args, **kwargs)
        >>> ChainIter([5, 6]).stararg(lambda x, y: x * y)
        30
        """
        return func(*tuple(self.data), *args, **kwargs)

    def calc(self) -> 'ChainIter':
        """
        ChainIter.data may be list, map, filter and so on.
        This method translate it to list.
        If you do not run parallel, it can print progress bar if you want.

        Parameters
        ----------
        max_len: int = 0
            Length of data.
            If 0, it print pivot bar.

        Returns
        ----------
        ChainIter object
        """
        if self.bar:
            res = []
            start_time = current_time = time.time()
            bar_str = '\r{percent}%[{bar}{arrow}{space}]{div}'
            cycle_token = ('-', '\\', '|', '/')
            cycle_str = '\r[{cycle}]'
            stat_str = ' | {epoch_time:.2g}sec/epoch | Speed: {speed:.2g}/sec'
            progress = bar_str + stat_str
            cycle = cycle_str + stat_str
            for n, v in enumerate(self.data):
                res.append(v)
                prev_time = current_time
                current_time = time.time()
                epoch_time = current_time - prev_time
                # Guard against division by zero on very fast iterations.
                speed = 1 / epoch_time if epoch_time else float('inf')
                if self.max != 0:
                    bar_num = int((n + 1) / self.max * self.bar_len)
                    print(progress.format(
                        percent=int(100 * (n + 1) / self.max),
                        bar='=' * bar_num,
                        arrow='>',
                        space=' ' * (self.bar_len - bar_num),
                        div=' ' + str(n + 1) + '/' + str(self.max),
                        epoch_time=round(epoch_time, 3),
                        speed=round(speed, 3)
                        ), end='')
                else:
                    print(cycle.format(
                        cycle=cycle_token[n % 4],
                        div=' ' + str(n + 1) + '/' + str(self.max),
                        epoch_time=round(epoch_time, 3),
                        speed=round(speed, 3)
                        ), end='')
            print('\nComplete in {sec} sec!'.format(
                sec=round(time.time()-start_time, 3)))
            self.data = res
            return self
        self.data = list(self.data)
        return self

    def __len__(self) -> int:
        self.calc()
        return len(cast(list, self.data))

    def __repr__(self) -> str:
        return 'ChainIter[{}]'.format(str(self.data))

    def __str__(self) -> str:
        return 'ChainIter[{}]'.format(str(self.data))

    def print(self) -> 'ChainIter':
        print(self.data)
        return self


if __name__ == '__main__':
    testmod()
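
Finally, a hypothetical end-to-end run across two worker processes (slow_double is my own example name; the __main__ guard matters so that Pool can import the function in its child processes):

async def slow_double(x):
    return x * 2

if __name__ == '__main__':
    print(ChainIter(range(4)).async_map(slow_double, 2).get())
    # -> [0, 2, 4, 6]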
