I tried to implement PLSA in Python 2


I have revised the implementation from last time. The changes are:

--Reduced memory usage

--Added some error handling

Speeding it up will come in a later post.

Memory usage reduction

Cause and approach

The previous implementation used a lot of memory. The cause is that the E step of the EM algorithm computed P(z | x, y) and stored it as a plain three-dimensional array over z, x and y. If we avoid creating that array, memory usage drops considerably.
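To get a feel for the size of that array, here is a back-of-the-envelope estimate. It assumes 8-byte float64 elements and ignores NumPy's small per-array overhead; the helper name `dense_array_mib` is just for illustration.

```python
def dense_array_mib(shape, itemsize=8):
    """Size in MiB of a dense array of the given shape (8-byte float64 by default)."""
    n = 1
    for s in shape:
        n *= s
    return n * itemsize / float(2 ** 20)


X, Y, Z = 5000, 5000, 10
print(dense_array_mib((Z, X, Y)))  # P(z|x,y) stored for every (z, x, y): about 1907.3
print(dense_array_mib((X, Y)))     # the X x Y matrix N itself: about 190.7
```

These estimates are close to the 1907.8 MiB increment for plsa.e_step() and the 191.7 MiB increment for creating PLSA in the X = Y = 5000 profile of the previous implementation later in this post, which suggests the 3D array is what dominated the old memory cost.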

As a concrete example, look at the M-step update formula for P(x | z). The denominator is just a normalization constant, so we only need the numerator:

P\left( x | z \right) \propto \sum_{y} N_{x, y} P\left( z | x, y \right)

Substituting the E-step update formula for P(z | x, y) into this gives

P\left( x | z \right) \propto \sum_{y} N_{x, y} \frac{P\left( z \right)P\left( x | z \right)P\left( y | z \right)}{\sum_{z'} P\left( z' \right)P\left( x | z' \right)P\left( y | z' \right)} \tag{1}

I want to implement this expression without writing explicit for loops, so I use NumPy's einsum.


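einsum evaluates sums written in Einstein summation notation. It is hard to explain in the abstract, so here is one example. The formula

P(x,y) = \sum_{z}P(z)P(x|z)P(y|z)

can be implemented as

Pxy = numpy.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)

The subscript string lists the axes of each operand (k for z, i for x, j for y), and any index that does not appear after -> (here k) is summed over.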
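To check that the einsum call really computes this sum, here is a small sketch that compares it with the equivalent triple loop on random toy distributions (the sizes and seed are arbitrary).

```python
import numpy as np

rng = np.random.RandomState(0)
Z, X, Y = 3, 4, 5

# Random, normalized P(z), P(x|z), P(y|z); each row of Px_z and Py_z sums to 1
Pz = rng.rand(Z)
Pz /= Pz.sum()
Px_z = rng.rand(Z, X)
Px_z /= Px_z.sum(axis=1)[:, None]
Py_z = rng.rand(Z, Y)
Py_z /= Py_z.sum(axis=1)[:, None]

# P(x, y) = sum_z P(z) P(x|z) P(y|z)
Pxy = np.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)

# The same sum written as explicit loops
Pxy_loop = np.zeros((X, Y))
for i in range(X):
    for j in range(Y):
        for k in range(Z):
            Pxy_loop[i, j] += Pz[k] * Px_z[k, i] * Py_z[k, j]

print(np.allclose(Pxy, Pxy_loop))  # True
print(np.isclose(Pxy.sum(), 1.0))  # True: P(x, y) sums to 1 over all x and y
```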

Equation (1) is awkward to implement with einsum as written, so we rearrange it as follows:

P\left( x | z \right) \propto \sum_{y} \frac{N_{x, y}}{\sum_{z'} P\left( z' \right)P\left( x | z' \right)P\left( y | z' \right)} P\left( z \right)P\left( x | z \right)P\left( y | z \right)

Implemented with einsum, this becomes:

tmp = N / numpy.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)
Px_z = numpy.einsum('ij,k,ki,kj->ki', tmp, Pz, Px_z, Py_z)
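To convince yourself that the rearranged form gives the same numerator, here is a small sketch that builds P(z|x,y) explicitly as a three-dimensional array (my own reconstruction of the "store the whole E step" route, not the previous post's code) and compares it with the two einsum lines above. The toy sizes, seed and count matrix N are arbitrary.

```python
import numpy as np

rng = np.random.RandomState(1)
X, Y, Z = 6, 7, 3
N = rng.randint(0, 5, size=(X, Y)).astype(float)  # toy count matrix N[x, y]
Pz = rng.rand(Z)
Pz /= Pz.sum()
Px_z = rng.rand(Z, X)
Px_z /= Px_z.sum(axis=1)[:, None]
Py_z = rng.rand(Z, Y)
Py_z /= Py_z.sum(axis=1)[:, None]

# Route 1: materialize P(z|x,y) with shape (Z, X, Y), then sum N * P(z|x,y) over y
joint = Pz[:, None, None] * Px_z[:, :, None] * Py_z[:, None, :]
Pz_xy = joint / joint.sum(axis=0)
num_3d = (N[None, :, :] * Pz_xy).sum(axis=2)  # shape (Z, X)

# Route 2: the einsum version from the text, without building a (Z, X, Y) array
tmp = N / np.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)
num_einsum = np.einsum('ij,k,ki,kj->ki', tmp, Pz, Px_z, Py_z)

print(np.allclose(num_3d, num_einsum))  # True
```

Both arrays are still unnormalized numerators; dividing each row by its sum over x gives P(x|z).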

I measure how much memory this saves later in this post.

Error handling

In the einsum implementation, consider what happens when this line divides by 0:

tmp = N / numpy.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)

The denominator is 0 when, for some x and y,

\sum_{z}P(z)P(x|z)P(y|z) = 0

Since no term in this sum can be negative, it follows that for that x and y and for every z,

P(z)P(x|z)P(y|z) = 0

So the numerator of the E step is 0 as well, and the E-step expression becomes 0/0. We define it to be 0:

P(z|x,y) = \frac{P(z)P(x|z)P(y|z)}{\sum_{z'} P(z')P(x|z')P(y|z')} = \frac{0}{0} := 0

Therefore, every element of tmp that comes from a division by 0 is set to 0.

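For NumPy floating-point arrays, dividing a positive number by 0 gives inf and dividing 0 by 0 gives nan (NumPy also emits a RuntimeWarning):

1 / 0 = inf
0 / 0 = nan

So we catch each case with numpy.isinf and numpy.isnan, respectively:

tmp = N / numpy.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)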
tmp[numpy.isinf(tmp)] = 0
tmp[numpy.isnan(tmp)] = 0

Px_z = numpy.einsum('ij,k,ki,kj->ki', tmp, Pz, Px_z, Py_z)

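Here is a tiny sketch of the inf/nan cleanup on a hand-made 2 × 2 example in which the first row of P(x, y) is 0. The np.errstate block only silences NumPy's RuntimeWarnings. np.divide with where= is shown as an alternative that never produces inf or nan in the first place; that variant is my addition, not what the implementation uses.

```python
import numpy as np

# Hand-made example: the first row of P(x, y) is 0
N = np.array([[2.0, 0.0],
              [1.0, 3.0]])
Pxy = np.array([[0.0, 0.0],
                [0.5, 0.5]])

# The approach in the text: divide, then reset inf (2/0) and nan (0/0) to 0.
# np.errstate only silences the RuntimeWarnings inside this block.
with np.errstate(divide='ignore', invalid='ignore'):
    tmp = N / Pxy
tmp[np.isinf(tmp)] = 0
tmp[np.isnan(tmp)] = 0

# Alternative: divide only where the denominator is nonzero; other entries keep out's 0
tmp2 = np.divide(N, Pxy, out=np.zeros_like(N), where=(Pxy != 0))

print(tmp)                        # rows [0, 0] and [2, 6]
print(np.array_equal(tmp, tmp2))  # True
```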


Putting it all together, the full implementation is as follows.


import numpy as np

class PLSA(object):
    def __init__(self, N, Z):
        self.N = N
        self.X = N.shape[0]
        self.Y = N.shape[1]
        self.Z = Z

        # P(z)
        self.Pz = np.random.rand(self.Z)
        # P(x|z)
        self.Px_z = np.random.rand(self.Z, self.X)
        # P(y|z)
        self.Py_z = np.random.rand(self.Z, self.Y)

        self.Pz /= np.sum(self.Pz)
        self.Px_z /= np.sum(self.Px_z, axis=1)[:, None]
        self.Py_z /= np.sum(self.Py_z, axis=1)[:, None]

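    def train(self, k=200, t=1.0e-7):
        '''
        Repeat the E and M steps until the log-likelihood converges
        '''
        prev_llh = 100000
        for i in range(k):
            self.em_algorithm()
            llh = self.llh()

            # Stop when the relative change in log-likelihood falls below t
            if abs((llh - prev_llh) / prev_llh) < t:
                break

            prev_llh = llh

    def em_algorithm(self):
        '''
        EM algorithm: update P(z), P(x|z), P(y|z)
        '''
        # N(x, y) / P(x, y); entries where P(x, y) = 0 become inf or nan and are reset to 0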
        tmp = self.N / np.einsum('k,ki,kj->ij', self.Pz, self.Px_z, self.Py_z)
        tmp[np.isnan(tmp)] = 0
        tmp[np.isinf(tmp)] = 0

        Pz = np.einsum('ij,k,ki,kj->k', tmp, self.Pz, self.Px_z, self.Py_z)
        Px_z = np.einsum('ij,k,ki,kj->ki', tmp, self.Pz, self.Px_z, self.Py_z)
        Py_z = np.einsum('ij,k,ki,kj->kj', tmp, self.Pz, self.Px_z, self.Py_z)

        self.Pz = Pz / np.sum(Pz)
        self.Px_z = Px_z / np.sum(Px_z, axis=1)[:, None]
        self.Py_z = Py_z / np.sum(Py_z, axis=1)[:, None]

    def llh(self):
        '''
        Log-likelihood
        '''
        Pxy = np.einsum('k,ki,kj->ij', self.Pz, self.Px_z, self.Py_z)
        Pxy /= np.sum(Pxy)
        lPxy = np.log(Pxy)
        lPxy[np.isinf(lPxy)] = -1000

        return np.sum(self.N * lPxy)

When computing the log-likelihood, P(x, y) can underflow to 0, giving log(0) = -inf. The smallest positive double-precision number is about 4.94e-324, and log(4.94e-324) ≈ -744.4, so no positive double has a log below that. I therefore substitute -1000 as a rough stand-in for -inf.
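To check these numbers and see the clamp in action, here is a short sketch. np.nextafter(0.0, 1.0) returns the smallest positive double; the sample probabilities are arbitrary.

```python
import math

import numpy as np

# Smallest positive (subnormal) double, about 4.94e-324
tiny = np.nextafter(0.0, 1.0)
print(math.log(tiny))  # about -744.44: no finite log of a positive double goes lower

# The clamp used in llh(): -inf becomes -1000, finite logs are left untouched
Pxy = np.array([0.5, 1e-300, 0.0])
with np.errstate(divide='ignore'):  # silence the log(0) warning
    lPxy = np.log(Pxy)
lPxy[np.isinf(lPxy)] = -1000
print(lPxy)  # roughly [-0.69, -690.8, -1000.]
```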


I used memory_profiler to measure how much memory this saves, with the following script.


import numpy as np
from memory_profiler import profile

X = 1000
Y = 1000
Z = 10

@profile
def main():
    from plsa import PLSA
    plsa = PLSA(np.random.rand(X, Y), Z)
    plsa.em_algorithm()
    llh = plsa.llh()

if __name__ == '__main__':
    main()

With X = 1000, Y = 1000, Z = 10, the previous implementation gives:

$ python profile_memory_element_wise_product.py 
Filename: profile_memory_element_wise_product.py

Line #    Mem usage    Increment   Line Contents
    10     15.9 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.9 MiB      0.0 MiB       from plsa_element_wise_product import PLSA
    13     23.9 MiB      8.0 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14    108.0 MiB     84.1 MiB       plsa.e_step()
    15    184.5 MiB     76.5 MiB       plsa.m_step()
    16    199.8 MiB     15.3 MiB       llh = plsa.llh()

and this implementation gives:

$ python profile_memory_einsum.py 
Filename: profile_memory_einsum.py

Line #    Mem usage    Increment   Line Contents
    10     15.8 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.8 MiB      0.0 MiB       from plsa_einsum import PLSA
    13     23.7 MiB      7.9 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14     40.7 MiB     16.9 MiB       plsa.em_algorithm()
    15     48.4 MiB      7.8 MiB       llh = plsa.llh()

With X = 5000, Y = 5000, Z = 10, the previous implementation gives:

$ python profile_memory_element_wise_product.py 
Filename: profile_memory_element_wise_product.py

Line #    Mem usage    Increment   Line Contents
    10     15.9 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.9 MiB      0.0 MiB       from plsa_element_wise_product import PLSA
    13    207.6 MiB    191.7 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14   2115.4 MiB   1907.8 MiB       plsa.e_step()
    15   2115.5 MiB      0.1 MiB       plsa.m_step()
    16   2115.5 MiB      0.0 MiB       llh = plsa.llh()

and this implementation gives:

$ python profile_memory_einsum.py 
Filename: profile_memory_einsum.py

Line #    Mem usage    Increment   Line Contents
    10     15.7 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.7 MiB      0.0 MiB       from plsa_einsum import PLSA
    13    207.5 MiB    191.7 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14    233.0 MiB     25.6 MiB       plsa.em_algorithm()
    15    233.1 MiB      0.0 MiB       llh = plsa.llh()

In summary, the total memory usage is:

Implementation           X=1000, Y=1000, Z=10    X=5000, Y=5000, Z=10
Previous implementation  199.8 MiB               2115.5 MiB
This implementation      48.4 MiB                233.1 MiB

Counting only the EM algorithm and the log-likelihood computation, it is:

Implementation           X=1000, Y=1000, Z=10    X=5000, Y=5000, Z=10
Previous implementation  175.9 MiB               1907.9 MiB
This implementation      24.7 MiB                25.6 MiB

This implementation adds very little memory on top of the input data, even at X = Y = 5000. With this, I no longer need to worry about memory as the data grows.
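If you want a rough cross-check without memory_profiler, Python 3's standard-library tracemalloc module can report peak allocations. This sketch assumes that NumPy reports its array buffers to tracemalloc (recent NumPy versions do). It compares a hypothetical "materialize P(z|x,y)" route with the einsum route on arbitrary toy sizes, and the numbers are not meant to match the memory_profiler output above.

```python
import tracemalloc

import numpy as np

rng = np.random.RandomState(0)
X, Y, Z = 300, 300, 10
N = rng.rand(X, Y)
Pz = rng.rand(Z)
Pz /= Pz.sum()
Px_z = rng.rand(Z, X)
Px_z /= Px_z.sum(axis=1)[:, None]
Py_z = rng.rand(Z, Y)
Py_z /= Py_z.sum(axis=1)[:, None]


def numerator_3d():
    # Materializes P(z|x,y) as a (Z, X, Y) array
    joint = Pz[:, None, None] * Px_z[:, :, None] * Py_z[:, None, :]
    Pz_xy = joint / joint.sum(axis=0)
    return (N[None, :, :] * Pz_xy).sum(axis=2)


def numerator_einsum():
    # The einsum formulation from the text
    tmp = N / np.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)
    return np.einsum('ij,k,ki,kj->ki', tmp, Pz, Px_z, Py_z)


def peak_mib(func):
    """Peak memory traced by tracemalloc while func() runs, in MiB."""
    tracemalloc.start()
    func()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak / 2.0 ** 20


peak_3d = peak_mib(numerator_3d)
peak_einsum = peak_mib(numerator_einsum)
print(peak_3d, peak_einsum)
```

The 3D route should peak at several times the memory of the einsum route, since it holds at least one Z × X × Y array while the einsum route only needs X × Y temporaries.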

Calculation speed

einsum is convenient, but contracting three or four operands in a single call, as done here, is slow. On my MacBook, one EM iteration plus one log-likelihood computation took about 8.7 seconds with X = 5000, Y = 5000, Z = 10. That is not practical and needs improvement, but I will leave it for a later post.
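One possible direction for the speed-up is sketched below. This is my own sketch, not necessarily what the follow-up post does. For each z, every einsum above sums over only x or only y at a time, so the same update can be written with matrix products, which NumPy typically hands to BLAS. The function names `em_step_matmul` and `em_step_einsum` are hypothetical. The check at the end compares the two on toy data; I make no claim about actual timings.

```python
import numpy as np


def em_step_matmul(N, Pz, Px_z, Py_z):
    """Hypothetical helper: one EM update written with matrix products."""
    # P(x, y) = sum_z P(z) P(x|z) P(y|z) as an (X, Z) x (Z, Y) product
    Pxy = np.dot(Px_z.T * Pz, Py_z)
    with np.errstate(divide='ignore', invalid='ignore'):
        tmp = N / Pxy
    tmp[np.isinf(tmp)] = 0
    tmp[np.isnan(tmp)] = 0

    # M-step numerators: each inner sum over y (or x) is one matrix product
    Px_z_new = Pz[:, None] * Px_z * np.dot(Py_z, tmp.T)  # (Z, X)
    Py_z_new = Pz[:, None] * Py_z * np.dot(Px_z, tmp)    # (Z, Y)
    Pz_new = Px_z_new.sum(axis=1)                        # (Z,)

    return (Pz_new / Pz_new.sum(),
            Px_z_new / Px_z_new.sum(axis=1)[:, None],
            Py_z_new / Py_z_new.sum(axis=1)[:, None])


def em_step_einsum(N, Pz, Px_z, Py_z):
    """The einsum update from em_algorithm() above, for comparison."""
    with np.errstate(divide='ignore', invalid='ignore'):
        tmp = N / np.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)
    tmp[np.isinf(tmp)] = 0
    tmp[np.isnan(tmp)] = 0
    Pz_new = np.einsum('ij,k,ki,kj->k', tmp, Pz, Px_z, Py_z)
    Px_z_new = np.einsum('ij,k,ki,kj->ki', tmp, Pz, Px_z, Py_z)
    Py_z_new = np.einsum('ij,k,ki,kj->kj', tmp, Pz, Px_z, Py_z)
    return (Pz_new / Pz_new.sum(),
            Px_z_new / Px_z_new.sum(axis=1)[:, None],
            Py_z_new / Py_z_new.sum(axis=1)[:, None])


# Toy data; sizes and seed are arbitrary
rng = np.random.RandomState(2)
X, Y, Z = 40, 30, 5
N = rng.randint(0, 3, size=(X, Y)).astype(float)
Pz = rng.rand(Z)
Pz /= Pz.sum()
Px_z = rng.rand(Z, X)
Px_z /= Px_z.sum(axis=1)[:, None]
Py_z = rng.rand(Z, Y)
Py_z /= Py_z.sum(axis=1)[:, None]

fast = em_step_matmul(N, Pz, Px_z, Py_z)
slow = em_step_einsum(N, Pz, Px_z, Py_z)
print(all(np.allclose(a, b) for a, b in zip(fast, slow)))  # True
```

The P(z) numerator is just the P(x|z) numerator summed over x, so it costs no extra product.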


--It turned out longer than I expected.

--einsum is really handy.
