This article speeds up my from-scratch MCMC implementation by running it with multiprocessing. In my earlier article, "[Statistics] Markov Chain Monte Carlo (MCMC) sampling with animation.", the implementation used only a single chain. This time I sample with multiple chains and run each chain in its own process. MCMC chains are independent of one another, so the work can simply be split across processes, and the speedup was easy to get.
Note: my machine has only 2 cores, so the speedup is effectively limited to 2 processes.
The code is posted on GitHub. https://github.com/matsuken92/Qiita_Contents/blob/master/multiprocessing/parallel_MCMC.ipynb
First, let's see how multiprocessing behaves with a simple computation.
First, import the libraries. We use the Pool class, which manages a pool of worker processes. The timing code below also needs time and NumPy.

```python
import time

import numpy as np
from multiprocessing import Pool
```
As a stand-in for heavy processing, let's use a function that loops many times. It only adds numbers, but about 100,000,000 iterations take a few seconds.
```python
def test_calc(num):
    """Heavy processing"""
    _sum = 0
    for i in xrange(num):
        _sum += i
    return _sum
```
Let's measure the speed when this process is executed twice in order.
```python
# Measure the time when it is executed twice sequentially
start = time.time()
_sum = 0
for _ in xrange(2):
    _sum += test_calc(100000000)
end = time.time()
print _sum
print "time: {}".format(end-start)
```
It took less than 12 seconds.
Output:

```
9999999900000000
time: 11.6906960011
```
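You can check the printed value without running the loop. test_calc(num) adds 0 through num - 1, and that sum is num(num - 1)/2. Below is a small Python 3 sketch that confirms the loop and the closed form agree and then reproduces the total above. The helper names loop_sum and closed_form_sum are hypothetical.

```python
def loop_sum(num):
    """Same computation as test_calc: add 0, 1, ..., num - 1 in a loop."""
    total = 0
    for i in range(num):
        total += i
    return total


def closed_form_sum(num):
    """Arithmetic-series formula for 0 + 1 + ... + (num - 1)."""
    return num * (num - 1) // 2


# The loop and the formula agree on a small input
assert loop_sum(1000) == closed_form_sum(1000)

# Two sequential calls with num = 100000000, as in the timing above
print(2 * closed_form_sum(100000000))  # 9999999900000000
```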
Next, run the same computation in two processes in parallel and measure the time.
```python
# Measure the time when executed in 2 processes
n_worker = 2
pool = Pool(processes=n_worker)

# Argument list to pass to the function executed by the two processes
args = [100000000] * n_worker

start = time.time()  # measurement
result = pool.map(test_calc, args)
end = time.time()  # measurement
print np.sum(result)
print "time: {}".format(end-start)

pool.close()
pool.join()  # wait for the worker processes to exit
```
It took a little over 6 seconds, almost half the sequential time, so running 2 processes did speed things up.
Output:

```
9999999900000000
time: 6.28346395493
```
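Here is "almost half the time" expressed as numbers. Speedup is the sequential time divided by the parallel time. Parallel efficiency is the speedup divided by the number of workers. The following minimal Python 3 sketch uses the two times measured above. The function names are hypothetical.

```python
def speedup(t_sequential, t_parallel):
    """How many times faster the parallel run was."""
    return t_sequential / t_parallel


def efficiency(t_sequential, t_parallel, n_workers):
    """Speedup per worker; 1.0 would be a perfect linear speedup."""
    return speedup(t_sequential, t_parallel) / n_workers


# Times measured above (seconds)
t_seq, t_par, n_workers = 11.6906960011, 6.28346395493, 2
print(round(speedup(t_seq, t_par), 2))                # 1.86
print(round(efficiency(t_seq, t_par, n_workers), 2))  # 0.93
```

An efficiency somewhat below 1.0 is expected. Handing the arguments to the workers and collecting the results through pool.map presumably costs a little time on top of the computation itself.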
Now, let's apply this to processing each chain of MCMC sampling in parallel. As always, the first thing to do is import the library.
```python
import numpy as np
import numpy.random as rd
import scipy.stats as st
import copy, time, os
from datetime import datetime as dt
from multiprocessing import Pool
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid", palette="muted", color_codes=True)
```
The function P(・) is the kernel of the target posterior distribution. Here we use the kernel of a two-dimensional normal distribution.
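To see why P is a two-dimensional normal kernel, write x = (x_1, x_2)^T and collect the quadratic form in the exponent into a matrix Λ:

$$
\begin{aligned}
P(x_1, x_2) &= \exp\!\left(-\tfrac{1}{2}\left(x_1^2 - 2b\,x_1 x_2 + x_2^2\right)\right)
             = \exp\!\left(-\tfrac{1}{2}\,\mathbf{x}^\top \Lambda\, \mathbf{x}\right),
  \qquad \Lambda = \begin{pmatrix} 1 & -b \\ -b & 1 \end{pmatrix}, \\
\Sigma &= \Lambda^{-1} = \frac{1}{1 - b^2}\begin{pmatrix} 1 & b \\ b & 1 \end{pmatrix}.
\end{aligned}
$$

The target is therefore a zero-mean bivariate normal with covariance Σ. Both marginal variances are 1/(1 - b^2), and the correlation is b. With b = 0.5 this gives variances of 4/3 and a correlation of 0.5. The eigenvalues of Λ are 1 ± b, so the condition |b| < 1 checked by the assert is exactly what makes Λ positive definite.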
```python
# Probability density with its normalization constant omitted (the kernel)
def P(x1, x2, b):
    assert np.abs(b) < 1
    return np.exp(-0.5*(x1**2 - 2*b*x1*x2 + x2**2))
```
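The sketch below is a quick numerical check that P really is a bivariate normal kernel. The matching covariance is the inverse of [[1, -b], [-b, 1]]. If P is that density without its normalizing constant, then P(x) divided by SciPy's multivariate_normal pdf should equal the same constant, 2π√det Σ, at every point. The test points are arbitrary choices.

```python
import numpy as np
from scipy.stats import multivariate_normal


def P(x1, x2, b):
    """Unnormalized target density, same as in the article."""
    assert np.abs(b) < 1
    return np.exp(-0.5 * (x1**2 - 2 * b * x1 * x2 + x2**2))


b = 0.5
# Covariance implied by the kernel: the inverse of [[1, -b], [-b, 1]]
cov = np.array([[1.0, b], [b, 1.0]]) / (1 - b**2)
target = multivariate_normal(mean=[0.0, 0.0], cov=cov)

# Arbitrary test points
points = [(0.0, 0.0), (1.0, -0.5), (2.0, 1.5), (-1.2, 0.3)]
ratios = np.array([P(x1, x2, b) / target.pdf([x1, x2]) for x1, x2 in points])

# The ratio should be the omitted normalizing constant 2*pi*sqrt(det(cov))
expected = 2 * np.pi * np.sqrt(np.linalg.det(cov))
print(np.allclose(ratios, expected))  # True
```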
The parameter b of the target distribution and the step size delta of the proposal distribution are defined as globals. We also define now(), a helper for time measurement that returns the current time as a string.
```python
# global parameters
b = 0.5
delta = 1

def now():
    return dt.strftime(dt.now(), '%H:%M:%S')
```
The following function performs the sampling. It keeps sampling until the specified number of samples is reached. It is almost the same as last time, except that it is wrapped in a function and has timing code added. The key point is that this is the function we run in parallel. Each call samples one chain, and the chains are independent, so the processes need no interprocess communication, which keeps things simple.
```python
def exec_sampling(n_samples):
    global b, delta
    # Reseed in each worker; forked workers would otherwise start from the same random state
    rd.seed(int(time.time())+os.getpid())
    pid = os.getpid()
    start = time.time()
    start_time = now()

    # initial state
    sampling_result = []
    current = np.array([5, 5])
    sampling_result.append(current)
    cnt = 1
    while cnt < n_samples:
        # Proposal: current state plus normal noise with standard deviation delta
        proposal = current + rd.normal(0, delta, size=2)
        r = P(proposal[0], proposal[1], b)/P(current[0], current[1], b)

        if r > 1 or r > rd.uniform(0, 1):
            # Accept when r exceeds a Uniform(0, 1) random number
            current = copy.copy(proposal)
        # Record the current state whether or not the proposal was accepted
        sampling_result.append(current)
        cnt += 1

    end = time.time()
    end_time = now()

    # Display of required time for each chain
    print "PID:{}, exec time: {}, {}-{}".format(pid, end-start, start_time, end_time)
    return sampling_result
```
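The reseeding line rd.seed(int(time.time()) + os.getpid()) matters. When Pool starts workers by forking (the usual case on Unix), each worker inherits a copy of the parent's global NumPy random state. Without reseeding, every chain could draw the same random numbers. One alternative, assuming NumPy 1.17 or later, is to derive one independent seed per chain from a single SeedSequence and give each chain its own Generator. The following minimal sketch uses a hypothetical helper, make_chain_rngs.

```python
import numpy as np


def make_chain_rngs(n_chain, entropy=None):
    """Return one independent Generator per chain.

    SeedSequence.spawn derives independent child seeds from a single root
    seed, so the chains do not depend on wall-clock time or PIDs.
    Passing the same entropy again reproduces the same streams.
    """
    root = np.random.SeedSequence(entropy)
    return [np.random.default_rng(child) for child in root.spawn(n_chain)]


rngs = make_chain_rngs(2, entropy=12345)
# Each chain gets a different stream of proposal noise
print(rngs[0].normal(0, 1, size=2), rngs[1].normal(0, 1, size=2))
```

In a Pool setup, you could pass the spawned SeedSequence children (which can be pickled) to the workers along with the sample count. Each worker would then build its own Generator from its child.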
The following three functions, draw_scatter(), draw_traceplot(), and remove_burn_in_samples(), process the sampling results.
```python
def draw_scatter(sample, alpha=0.3):
    """Draw a scatter plot of sampling results"""
    plt.figure(figsize=(9,9))
    plt.scatter(sample[:,0], sample[:,1], alpha=alpha)
    plt.title("Scatter plot of 2-dim normal random variable with MCMC. sample size:{}".format(len(sample)))
    plt.show()

def draw_traceplot(sample):
    """Draw a trace plot of sampling results"""
    assert sample.shape[1] == 2
    plt.figure(figsize=(15, 6))
    for i in range(2):
        plt.subplot(2, 1, i+1)
        plt.xlim(0, len(sample[:,i]))
        plt.plot(sample[:,i], lw=0.05)
        if i == 0:
            order = "1st"
        else:
            order = "2nd"
        plt.title("Traceplot of {} parameter.".format(order))
    plt.show()

def remove_burn_in_samples(total_sampling_result, burn_in_rate=0.2):
    """Drop the burn-in section (the first burn_in_rate fraction) of each chain and pool the rest."""
    adjust_burn_in_result = []
    for i in xrange(len(total_sampling_result)):
        idx = int(len(total_sampling_result[i])*burn_in_rate)
        adjust_burn_in_result.extend(total_sampling_result[i][idx:])
    return np.array(adjust_burn_in_result)
```
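Here is the burn-in handling in concrete terms. With burn_in_rate=0.2, each chain loses its first 20% of samples before the chains are concatenated. The cut-off is int(len(chain) * burn_in_rate), the same rule remove_burn_in_samples uses. The following small Python 3 sketch implements the same logic with NumPy arrays and checks it on two toy chains of length 10. The name drop_burn_in is hypothetical.

```python
import numpy as np


def drop_burn_in(chains, burn_in_rate=0.2):
    """Drop the first burn_in_rate fraction of each chain, then stack them."""
    kept = []
    for chain in chains:
        chain = np.asarray(chain)
        idx = int(len(chain) * burn_in_rate)  # same cut-off as remove_burn_in_samples
        kept.append(chain[idx:])
    return np.concatenate(kept)


# Two toy chains of 10 two-dimensional samples each; values encode (chain, step)
chains = [[(c, t) for t in range(10)] for c in range(2)]
samples = drop_burn_in(chains)
print(samples.shape)  # (16, 2): 8 samples kept from each chain
print(samples[0])     # [0 2]: chain 0 now starts at step 2
```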
The following function runs the chains in parallel. Structurally it is essentially the same as the simple example at the beginning.
```python
def parallel_exec(n_samples, n_chain, burn_in_rate=0.2):
    """Execution of parallel processing"""
    # Calculate sample size per chain
    n_samples_per_chain = n_samples / float(n_chain)
    print "Making {} samples per {} chain. Burn-in rate:{}".format(n_samples_per_chain, n_chain, burn_in_rate)

    # Creating a Pool object
    pool = Pool(processes=n_chain)

    # Generate arguments for execution
    n_trials_per_process = [n_samples_per_chain] * n_chain

    # Execution of parallel processing
    start = time.time()  # measurement
    total_sampling_result = pool.map(exec_sampling, n_trials_per_process)
    end = time.time()  # measurement

    # Display of total required time
    print "total exec time: {}".format(end-start)

    # Drop burn-in at the requested rate, then draw scatter and trace plots
    adjusted_samples = remove_burn_in_samples(total_sampling_result, burn_in_rate)
    draw_scatter(adjusted_samples, alpha=0.01)
    draw_traceplot(adjusted_samples)
    pool.close()
```
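The next block puts the pieces together as a Python 3 sketch. It is not the article's code and assumes NumPy 1.17 or later. Each chain gets its own SeedSequence child, and the function that maps over the chains is passed in as mapper. Passing the built-in map runs the chains one after another. Passing pool.map from a multiprocessing.Pool would run them in parallel, which is the same swap the article makes. Unlike exec_sampling, each chain here returns a single NumPy array. The acceptance test r > u is the same Metropolis rule, because u < 1 already covers the r > 1 case. The names mh_chain and run_chains and the seed value are hypothetical.

```python
import numpy as np


def P(x1, x2, b):
    """Unnormalized bivariate normal kernel from the article."""
    return np.exp(-0.5 * (x1**2 - 2 * b * x1 * x2 + x2**2))


def mh_chain(args):
    """One random-walk Metropolis chain; args = (n_samples, seed_seq, b, delta).

    Takes a single tuple so it can be used with map / pool.map directly.
    """
    n_samples, seed_seq, b, delta = args
    rng = np.random.default_rng(seed_seq)
    current = np.array([5.0, 5.0])  # same starting point as the article
    samples = np.empty((n_samples, 2))
    samples[0] = current
    for t in range(1, n_samples):
        proposal = current + rng.normal(0, delta, size=2)
        r = P(proposal[0], proposal[1], b) / P(current[0], current[1], b)
        if r > rng.uniform(0, 1):
            current = proposal  # accept; otherwise keep the current state
        samples[t] = current   # a rejected step repeats the current state
    return samples


def run_chains(n_samples, n_chain, b=0.5, delta=1.0, burn_in_rate=0.2,
               seed=None, mapper=map):
    """Run n_chain chains via mapper and pool them after dropping burn-in."""
    children = np.random.SeedSequence(seed).spawn(n_chain)
    per_chain = n_samples // n_chain
    args = [(per_chain, child, b, delta) for child in children]
    chains = list(mapper(mh_chain, args))
    idx = int(per_chain * burn_in_rate)
    return np.concatenate([chain[idx:] for chain in chains])


samples = run_chains(40000, 2, seed=2024)
print(samples.shape)  # (32000, 2)
print(np.corrcoef(samples.T)[0, 1])  # should come out near b = 0.5
```

For a parallel run, something like `with Pool(2) as pool: run_chains(1000000, 2, mapper=pool.map)` should work. This applies on platforms that use the spawn start method (Windows, and macOS by default since Python 3.8). There, the functions must live in an importable module, and the pool code belongs under an `if __name__ == "__main__":` guard.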
Now let's see the actual effect. The total number of samples is 1,000,000, and we measure the cases with 2 chains and with 1 chain.
```python
# parameter: n_samples = 1000000, n_chain = 2
parallel_exec(1000000, 2)
```
Sampling finished in a little under 19 seconds in total, with each worker process taking about 12 seconds.
Output:

```
Making 500000.0 samples per 2 chain. Burn-in rate:0.2
total exec time: 18.6980280876
PID:2374, exec time: 12.0037689209, 20:53:41-20:53:53
PID:2373, exec time: 11.9927477837, 20:53:41-20:53:53
```
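The total (about 18.7 s) is noticeably longer than each worker's own time (about 12 s). pool.map has to pickle each worker's return value and send it back to the parent process. Here that value is a list of 500,000 separate small NumPy arrays, so the transfer presumably accounts for much of the gap, although I have not profiled it. The following Python 3 sketch compares the pickled size of a list of small arrays with the same data stored as one array. The sample count is an arbitrary choice.

```python
import pickle

import numpy as np

n = 10000
rng = np.random.default_rng(0)
as_list = [rng.normal(size=2) for _ in range(n)]  # same shape as exec_sampling's return value
as_array = np.array(as_list)                     # same numbers as one (n, 2) array

list_bytes = len(pickle.dumps(as_list, protocol=pickle.HIGHEST_PROTOCOL))
array_bytes = len(pickle.dumps(as_array, protocol=pickle.HIGHEST_PROTOCOL))
print(list_bytes, array_bytes)
print("list payload is about {:.0f}x larger".format(list_bytes / array_bytes))
```

If the transfer does turn out to be the bottleneck, returning np.array(sampling_result) from exec_sampling instead of the list should shrink what has to be sent back.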
```python
# parameter: n_samples = 1000000, n_chain = 1
parallel_exec(1000000, 1)
```
Running with a single worker process took a little under 33 seconds. Running two processes was therefore about 1.7 times faster (32.7 s / 18.7 s ≈ 1.75).
Output:

```
Making 1000000.0 samples per 1 chain. Burn-in rate:0.2
total exec time: 32.683218956
PID:2377, exec time: 24.7304420471, 20:54:07-20:54:31
```
Python Documentation (2.7ja1) 16.6. Multiprocessing — Process-based “parallel processing” interface http://docs.python.jp/2.7/library/multiprocessing.html
High Performance Python (O'Reilly) https://www.oreilly.co.jp/books/9784873117409/