Blame LFPy-2.0.7/examples/example_mpi_2.py

7d68d07
# -*- coding: utf-8 -*-
7d68d07
"""
7d68d07
LFPs from a population of cells relying on MPI (Message Passing Interface)
7d68d07
7d68d07
Execution:
7d68d07
7d68d07
    <mpiexec> -n <processes> python example_mpi_2.py
7d68d07
7d68d07
Notes:
7d68d07
- on certain platforms and with mpirun, the --oversubscribe argument is needed
7d68d07
  to get more processes than the number of physical CPU cores.
7d68d07
7d68d07
Copyright (C) 2017 Computational Neuroscience Group, NMBU.
7d68d07
7d68d07
This program is free software: you can redistribute it and/or modify
7d68d07
it under the terms of the GNU General Public License as published by
7d68d07
the Free Software Foundation, either version 3 of the License, or
7d68d07
(at your option) any later version.
7d68d07
7d68d07
This program is distributed in the hope that it will be useful,
7d68d07
but WITHOUT ANY WARRANTY; without even the implied warranty of
7d68d07
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
7d68d07
GNU General Public License for more details.
7d68d07
"""
7d68d07
7d68d07
import numpy as np
7d68d07
import matplotlib.pyplot as plt
7d68d07
from matplotlib.collections import PolyCollection, LineCollection
7d68d07
import os
7d68d07
from os.path import join
7d68d07
import sys
7d68d07
if sys.version < '3':
7d68d07
    from urllib2 import urlopen
7d68d07
else:
7d68d07
    from urllib.request import urlopen
7d68d07
import zipfile
7d68d07
import LFPy
7d68d07
import neuron
7d68d07
from mpi4py import MPI
7d68d07
7d68d07
# MPI bookkeeping: every process runs this whole script, SIZE is the number of
# processes and RANK identifies the current one.
COMM = MPI.COMM_WORLD
SIZE, RANK = COMM.Get_size(), COMM.Get_rank()


# seed NumPy's global random number generator (same seed on every rank)
global_seed = 1234
np.random.seed(global_seed)
7d68d07
7d68d07
def stationary_poisson(nsyn, lambd, tstart, tstop):
    """Draw ``nsyn`` independent stationary Poisson spike trains.

    Parameters
    ----------
    nsyn : int
        Number of spike trains to generate.
    lambd : float
        Firing rate of every train, in spikes per second (the window length
        is converted from ms to s with a factor .001 before scaling).
    tstart, tstop : float
        Window in ms within which the spike times are drawn.

    Returns
    -------
    list of ndarray
        One sorted array of spike times per train; trains that drew zero
        spikes are represented by an empty array.
    """
    duration = tstop - tstart
    # expected spike count per train: window length in seconds times rate
    mean_count = duration * .001 * lambd

    trains = []
    for _ in range(nsyn):
        n_spikes = np.random.poisson(mean_count)
        if not n_spikes:
            trains.append(np.empty(n_spikes))
            continue
        # uniformly place the spikes in the window, then order them in time
        trains.append(np.sort(tstart + duration * np.random.random(n_spikes)))
    return trains
7d68d07
7d68d07
7d68d07
# Fetch the Mainen & Sejnowski (1996) model files from ModelDB, on RANK 0 only
# and only if the morphology file is not already present.
if not os.path.isfile(join('cells', 'cells', 'j4a.hoc')) and RANK == 0:
    # get the model files. The downloaded payload is bytes, so the local file
    # must be opened in binary mode ('w' raises TypeError on Python 3).
    u = urlopen('http://senselab.med.yale.edu/ModelDB/eavBinDown.asp?o=2488&a=23&mime=application/zip')
    try:
        with open('patdemo.zip', 'wb') as localFile:
            localFile.write(u.read())
    finally:
        # release the network connection even if writing fails
        u.close()

    # unzip into the current working directory
    with zipfile.ZipFile('patdemo.zip', 'r') as myzip:
        myzip.extractall('.')


# all ranks wait here until RANK 0 is done fetching the model files
COMM.Barrier()
7d68d07
7d68d07
# Cell parameters, passed as keyword arguments to LFPy.Cell below.
cell_parameters = dict(
    morphology=join('cells', 'cells', 'j4a.hoc'),  # Mainen&Sejnowski, 1996
    cm=1.0,                     # membrane capacitance
    Ra=150,                     # axial resistance
    v_init=-65.,                # initial membrane potential
    passive=True,               # turn on passive mechanism for all sections
    passive_parameters=dict(g_pas=1. / 30000, e_pas=-65),  # passive params
    nsegs_method='lambda_f',    # segmentation method ...
    lambda_f=100.,              # ... and its parameter
    dt=2. ** -3,                # simulation time step size (ms)
    tstart=0.,                  # start time of simulation (ms)
    tstop=300.,                 # stop simulation at 300 ms
)
7d68d07
7d68d07
# Synapse parameters shared by every LFPy.Synapse; 'idx' is only a
# placeholder here and gets replaced for each synapse later on.
synapse_parameters = dict(
    idx=0,                  # to be set later
    e=0.,                   # reversal potential
    syntype='ExpSyn',       # synapse type
    tau=5.,                 # synaptic time constant
    weight=.001,            # synaptic weight
    record_current=True,
)
7d68d07
7d68d07
# Extracellular electrode parameters: a single contact at the origin.
point_electrode_parameters = dict(
    sigma=0.3,      # extracellular conductivity
    x=0.,           # contact coordinates
    y=0.,
    z=0.,
)
7d68d07
7d68d07
# population size
n_cells = 6


# cells are placed at evenly spaced positions along the x-axis
x_cell_pos = np.linspace(-250., 250., n_cells)


# rotation about the x- and y-axis shared by every cell
xy_rotations = {'x': 4.99, 'y': -4.33}
7d68d07
7d68d07
# Rotation about the z-axis of each cell: a random permutation of evenly spaced
# angles in [0, pi). It is drawn on RANK 0 only and broadcast, so every rank
# uses identical values.
z_rotation = COMM.bcast(
    np.random.permutation(np.arange(0., np.pi, np.pi / n_cells))
    if RANK == 0 else None,
    root=0)
7d68d07
7d68d07
7d68d07
# Presynaptic spike trains: generated once on RANK 0 and broadcast, so every
# process works with the same set of n_pre_syn trains.
n_pre_syn = 1000
pre_syn_sptimes = COMM.bcast(
    stationary_poisson(nsyn=n_pre_syn, lambd=5., tstart=0, tstop=300)
    if RANK == 0 else None,
    root=0)
7d68d07
7d68d07
# number of synapses on each cell
n_synapses = 100


# Indices of the presynaptic spike trains for each neuron are picked on RANK 0
# and scattered (for illustrating purposes, not efficiency).
# After the scatter, pre_syn_ids on rank r is a list holding one index array
# per cell simulated on r (cells r, r + SIZE, r + 2*SIZE, ... in that order),
# so cell `cell_id` finds its own indices at position cell_id // SIZE.
if RANK == 0:
    # One *independent* list per rank. `[[]] * SIZE` would repeat a single
    # list object SIZE times, handing every rank all index arrays, so that
    # cells on the same rank shared identical presynaptic inputs.
    pre_syn_ids = [[] for _ in range(SIZE)]
    for cell_id in range(n_cells):
        pre_syn_ids[cell_id % SIZE].append(
            np.random.permutation(np.arange(n_pre_syn))[:n_synapses])
else:
    pre_syn_ids = None
pre_syn_ids = COMM.scatter(pre_syn_ids, root=0)


# containers for per-cell LFP (filled on RANK 0 only) and summed LFPs
single_LFPs = []
summed_LFP = np.zeros(int(cell_parameters['tstop'] / cell_parameters['dt'] + 1))


# save the state of the random number generator, which the per-cell
# reseeding below overrides
state = np.random.get_state()


# iterate over cells in the population; cell `cell_id` runs on rank cell_id % SIZE
for cell_id in range(n_cells):
    if cell_id % SIZE == RANK:
        # reseed per cell so the random draws made while setting up this cell
        # (presumably including the synapse locations) depend only on cell_id
        np.random.seed(global_seed + cell_id)

        # Create, rotate and position the cell
        cell = LFPy.Cell(**cell_parameters)
        cell.set_rotation(z=z_rotation[cell_id], **xy_rotations)
        cell.set_pos(x=x_cell_pos[cell_id])

        # presynaptic spike-train indices belonging to this very cell
        cell_pre_syn_ids = pre_syn_ids[cell_id // SIZE]
        for i_syn in range(n_synapses):
            syn_idx = cell.get_rand_idx_area_norm()
            synapse_parameters.update({'idx': syn_idx})
            synapse = LFPy.Synapse(cell, **synapse_parameters)
            synapse.set_spike_times(pre_syn_sptimes[cell_pre_syn_ids[i_syn]])

        # run the cell simulation
        cell.simulate(rec_imem=True)

        # set up the extracellular device
        point_electrode = LFPy.RecExtElectrode(cell,
                                               **point_electrode_parameters)
        point_electrode.calc_lfp()

        # sum LFP on this RANK
        summed_LFP += point_electrode.LFP[0]

        # send LFP of this cell to RANK 0 (or keep it if already there)
        if RANK != 0:
            COMM.send(point_electrode.LFP[0], dest=0)
        else:
            single_LFPs.append(point_electrode.LFP[0])

    # RANK 0 collects the single-cell LFPs simulated on other ranks; receiving
    # in cell_id order keeps single_LFPs indexed by cell_id
    if RANK == 0 and cell_id % SIZE != RANK:
        single_LFPs.append(COMM.recv(source=cell_id % SIZE))
7d68d07
# MPI can also sum arrays directly: element-wise sum of the per-rank partial
# LFPs, delivered to RANK 0 (the other ranks get None back).
summed_LFP = COMM.reduce(summed_LFP, op=MPI.SUM, root=0)


# put the random number generator back into its saved state
np.random.set_state(state)
7d68d07
7d68d07
def _hide_top_right_spines(ax):
    """Hide the top/right spines of `ax` and keep ticks on bottom/left only."""
    for loc, spine in ax.spines.items():
        if loc in ['right', 'top']:
            spine.set_color('none')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')


# plot output on RANK 0.
if RANK == 0:
    # assign color to each unit
    color_vec = [plt.cm.rainbow(int(x * 256. / n_cells)) for x in range(n_cells)]

    # figure
    fig = plt.figure(figsize=(12, 8))

    # Morphologies axes:
    plt.axes([.175, .0, .65, 1], aspect='equal')
    plt.axis('off')

    for i_cell in range(n_cells):
        # coarsely segmented copy of each cell, only used for drawing
        cell = LFPy.Cell(join('cells', 'cells', 'j4a.hoc'),
                         nsegs_method='lambda_f',
                         lambda_f=5)
        cell.set_rotation(z=z_rotation[i_cell], **xy_rotations)
        cell.set_pos(x=x_cell_pos[i_cell])

        zips = [list(zip(x, z)) for x, z in cell.get_idx_polygons()]
        # The segment outlines are polygons to be filled: a LineCollection
        # with edgecolor='none' renders nothing, so use a PolyCollection.
        polycol = PolyCollection(zips,
                                 edgecolors='none',
                                 facecolors=color_vec[i_cell],
                                 rasterized=False)

        ax = plt.gca()
        ax.add_collection(polycol)

    axis = ax.axis(ax.axis('equal'))
    ax.axis(np.array(axis) / 1.15)

    # adding a blue dot:
    ax.plot(point_electrode.x, point_electrode.z, 'o',
            markeredgecolor='none', markerfacecolor='b', markersize=3,
            zorder=10, clip_on=False)
    # arrow colors given once via fc/ec; the former extra facecolor/edgecolor
    # aliases conflicted with them
    plt.annotate("Electrode",
                 xy=(0., 0.), xycoords='data',
                 xytext=(-100., 1000.),
                 arrowprops=dict(arrowstyle='wedge',
                                 shrinkA=1,
                                 shrinkB=1,
                                 mutation_scale=20,
                                 fc="0.6", ec="none"))

    plt.xlim([-700., 700.])

    # scale bar
    ax.plot([100, 200], [-250, -250], 'k', lw=1, clip_on=False)
    ax.text(150, -300, r'100$\mu$m', va='center', ha='center')

    # presynaptic spike trains axes
    plt.axes([.05, .35, .25, .55])

    # all presynaptic spikes pooled, with the index of the train each belongs to
    pop_sptimes = np.concatenate(pre_syn_sptimes)
    pop_train_ids = np.repeat(np.arange(n_pre_syn),
                              [len(sp) for sp in pre_syn_sptimes])
    # one scatter call for the whole raster instead of one per train
    plt.scatter(pop_sptimes, pop_train_ids,
                s=1, edgecolors='none', facecolors='k')

    plt.ylim([0, n_pre_syn])
    plt.xlim([0, cell_parameters['tstop']])
    plt.ylabel('train #', ha='left', labelpad=0)
    plt.title('Presynaptic spike times')

    ax = plt.gca()
    _hide_top_right_spines(ax)
    ax.set_xticklabels([])

    # spike rate axes
    plt.axes([.05, .12, .25, .2])

    binsize = 5
    bins = np.arange(0, cell_parameters['tstop'] + 1., binsize)
    count, b = np.histogram(pop_sptimes, bins=bins)
    # spikes per bin -> population-averaged rate in spikes/s
    rate = count * (1000. / binsize) * (1. / n_pre_syn)
    plt.plot(b[0:-1], rate, color='black', lw=1)

    plt.xlim([0, cell_parameters['tstop']])
    plt.ylim([0, 10.])

    # Time axis of the LFPs. The cells drawn above were created without
    # cell_parameters['dt'], so their `dt` is not necessarily the simulation
    # time step; use the simulation's own dt.
    tvec = np.arange(point_electrode.LFP.shape[1]) * cell_parameters['dt']

    plt.xlabel('$t$ (ms)')
    plt.ylabel('Rate (spike/s)')
    _hide_top_right_spines(plt.gca())

    # single neuron EPs axes
    plt.axes([.7, .35, .25, .55])

    plt.title('Single neuron extracellular potentials')
    plt.axis('off')

    # one trace per cell, vertically offset by its cell_id
    for cell_id in range(n_cells):
        plt.plot(tvec,
                 cell_id + 2.e3 * single_LFPs[cell_id],
                 color=color_vec[cell_id], lw=1)

    plt.ylim([-1, n_cells - .5])

    # Summed LFPs axes
    plt.axes([.7, .12, .25, .2])
    plt.plot(tvec, 1E3 * summed_LFP, color='black', lw=1)
    plt.ylim([-5.e-1, 5e-1])

    plt.title('Summed extracellular potentials')
    plt.xlabel(r'$t$ (ms)')
    plt.ylabel(r'$\mu$V', ha='left', rotation='horizontal')
    _hide_top_right_spines(plt.gca())

    fig.savefig('example_mpi_2.pdf', dpi=300)
    plt.show()