{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stim_data.pkl\n",
      "dict_keys(['stim_val', 'trial_stim_id', 'key_list', 'num_trial', 'trial_pair_id', 'pair_val', 'pair_trial_id', 'stim_id_trial', 'num_stim'])\n",
      "spike_data.pkl\n",
      "dict_keys(['spike_count_rate', 'avg_firing_rate', 'sem_firing_rate', 'firing_rate', 'stim_num_trial', 'C_r_fphi_theta', 'theta_hist', 'phase_hist', 'pair_hist'])\n",
      "corr_data.pkl\n",
      "dict_keys(['corr_stim_unit', 'optimal_avg_firing_rate', 'stim_hist', 'stim_hist_caution'])\n"
     ]
    }
   ],
   "source": [
    "# load data 1\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import scipy.stats as sts\n",
    "import pickle\n",
    "import ipywidgets as widgets\n",
    "from mpl_toolkits import mplot3d\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from scipy import sparse\n",
    "from matplotlib import cm\n",
    "import cmocean\n",
    "\n",
    "# pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9\n",
    "cmap_phase = cmocean.cm.phase\n",
    "cmap_hot = plt.get_cmap('hot')\n",
    "cmap_viridis = plt.get_cmap('viridis')\n",
    "\n",
    "execfile('Stimulus.py')\n",
    "data_folder = \"Data/\"\n",
    "\n",
    "stim_file = \"Stiminfo_PVCre_2021_0012_s06_e14.csv\"\n",
    "stim = pd.read_csv(data_folder+stim_file)\n",
    "\n",
    "spike_times_file = \"Spiketimes_PVCre_2021_0012_s06_e14.npy\"\n",
    "spike_times = np.load(data_folder+spike_times_file, allow_pickle=True)\n",
    "active = [len(spike_times[i]) > 0 for i in range(len(spike_times))]\n",
    "spike_times = spike_times[np.where(active)]\n",
    "\n",
    "num_unit = len(spike_times)\n",
    "num_trial = len(stim)\n",
    "\n",
    "# sort by firing rate\n",
    "\n",
    "num_spike = list(map(len, spike_times))\n",
    "# num_spike = np.array([len(spike_times[i]) for i in range(len(spike_times))])\n",
    "spike_times = spike_times[np.argsort(num_spike)[::-1]]\n",
    "execfile('load.py')\n",
    "\n",
    "max_delay = 300 # dt\n",
    "tau_id_range = np.arange(max_delay)\n",
    "\n",
    "latest_spike_time = max([np.max(s) for s in spike_times if len(s)])\n",
    "latest_stim_offtime = list(stim['stim_offtime'])[-1]\n",
    "experiment_dur = max([latest_spike_time, latest_stim_offtime])\n",
    "\n",
    "dt = 0.001 # 1 ms\n",
    "exp_time = np.arange(0, experiment_dur, dt)\n",
    "M = len(exp_time)\n",
    "\n",
    "# binary spike and stimulus trains\n",
    "B_stim = {}\n",
    "for key in key_list:\n",
    "    B_stim[key] = []\n",
    "    for stim_id, trials in enumerate(stim_id_trial[key]):\n",
    "        B_stim[key].append([])\n",
    "        s = []\n",
    "        for trial_id in trials:\n",
    "            t_on, t_off = stim['stim_ontime'][trial_id], stim['stim_offtime'][trial_id]\n",
    "            s += list(np.arange(int(t_on//dt), int(t_off//dt)))\n",
    "\n",
    "        B_stim[key][stim_id] = sparse.coo_matrix((np.ones(len(s)), (np.zeros(len(s), dtype=int), s)), shape=(1, M))\n",
    "s = spike_times//dt\n",
    "B_spike = []\n",
    "# spike bin indices as np.intp (np.int0, its old alias, was removed in NumPy 2.0)\n",
    "for unit_id in range(num_unit):\n",
    "    B_spike.append(sparse.coo_matrix((np.ones(len(s[unit_id])), (np.zeros(len(s[unit_id]), dtype=int), np.intp(s[unit_id]))), shape=(1, M)))\n",
    "\n",
    "# histogram error bars: num spikes\n",
    "s = np.zeros((num_unit, 2))\n",
    "for unit_id in range(num_unit):\n",
    "    # print(\"unit: %d\"%unit_id)\n",
    "    a = np.zeros(len(tau_id_range))\n",
    "    for tau_id in tau_id_range:\n",
    "        a[tau_id] = np.sum(B_spike[unit_id].col >= tau_id)\n",
    "\n",
    "    s[unit_id] = [np.mean(a), np.std(a)]\n",
    "\n",
    "# raw strings keep the LaTeX backslashes literal ('\\p' is not a valid escape sequence)\n",
    "key_symbol = {'pair': r'$(\\theta,\\phi)$', 'orientation': r'$\\theta$', 'phase': r'$\\phi$'}\n",
    "\n",
    "# 2D tuning\n",
    "avg_firing_rate_pair = np.array([sts.zscore(stim_hist['pair'][unit_id]).reshape((len(tau_id_range), num_stim['orientation'], num_stim['phase'])) for unit_id in range(num_unit)])\n",
    "\n",
    "sorted_spike_num = np.sort(num_spike)[::-1]\n",
    "cutoff_num_spike = 1000\n",
    "num_unit = np.sum(sorted_spike_num > cutoff_num_spike)\n",
    "for key in key_list:\n",
    "    stim_hist[key] = stim_hist[key][:num_unit]\n",
    "\n",
    "opt_time = np.load('optimal_time.npy')\n",
    "stim_val_pair = pd.DataFrame(stim_val['pair'], columns=['orientation', 'phase'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stim_data.pkl\n",
      "dict_keys(['stim_val', 'trial_stim_id', 'key_list', 'num_trial', 'trial_pair_id', 'pair_val', 'pair_trial_id', 'stim_id_trial', 'num_stim'])\n",
      "spike_data.pkl\n",
      "dict_keys(['spike_count_rate', 'avg_firing_rate', 'sem_firing_rate', 'firing_rate', 'stim_num_trial', 'C_r_fphi_theta', 'theta_hist', 'phase_hist', 'pair_hist'])\n",
      "corr_data.pkl\n",
      "dict_keys(['corr_stim_unit', 'optimal_avg_firing_rate', 'stim_hist', 'stim_hist_caution'])\n"
     ]
    }
   ],
   "source": [
    "# load data 0\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import scipy.stats as sts\n",
    "import pickle\n",
    "import ipywidgets as widgets\n",
    "from mpl_toolkits import mplot3d\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from scipy import sparse\n",
    "from matplotlib import cm\n",
    "\n",
    "# pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9\n",
    "cmap_hot = plt.get_cmap('hot')\n",
    "cmap_viridis = plt.get_cmap('viridis')\n",
    "\n",
    "execfile('Stimulus.py')\n",
    "data_folder = \"Data/\"\n",
    "\n",
    "stim_file = \"Stiminfo_PVCre_2021_0012_s06_e14.csv\"\n",
    "stim = pd.read_csv(data_folder+stim_file)\n",
    "\n",
    "spike_times_file = \"Spiketimes_PVCre_2021_0012_s06_e14.npy\"\n",
    "spike_times = np.load(data_folder+spike_times_file, allow_pickle=True)\n",
    "active = [len(spike_times[i]) > 0 for i in range(len(spike_times))]\n",
    "spike_times = spike_times[np.where(active)]\n",
    "\n",
    "num_unit = len(spike_times)\n",
    "num_trial = len(stim)\n",
    "\n",
    "# sort by firing rate\n",
    "\n",
    "num_spike = list(map(len, spike_times))\n",
    "# num_spike = np.array([len(spike_times[i]) for i in range(len(spike_times))])\n",
    "spike_times = spike_times[np.argsort(num_spike)[::-1]]\n",
    "execfile('load.py')\n",
    "\n",
    "max_delay = 300 # dt\n",
    "tau_id_range = np.arange(max_delay)\n",
    "\n",
    "latest_spike_time = max([np.max(s) for s in spike_times if len(s)])\n",
    "latest_stim_offtime = list(stim['stim_offtime'])[-1]\n",
    "experiment_dur = max([latest_spike_time, latest_stim_offtime])\n",
    "\n",
    "dt = 0.001 # 1 ms\n",
    "exp_time = np.arange(0, experiment_dur, dt)\n",
    "M = len(exp_time)\n",
    "\n",
    "# binary spike and stimulus trains\n",
    "B_stim = {}\n",
    "for key in key_list:\n",
    "    B_stim[key] = []\n",
    "    for stim_id, trials in enumerate(stim_id_trial[key]):\n",
    "        B_stim[key].append([])\n",
    "        s = []\n",
    "        for trial_id in trials:\n",
    "            t_on, t_off = stim['stim_ontime'][trial_id], stim['stim_offtime'][trial_id]\n",
    "            s += list(np.arange(int(t_on//dt), int(t_off//dt)))\n",
    "\n",
    "        B_stim[key][stim_id] = sparse.coo_matrix((np.ones(len(s)), (np.zeros(len(s), dtype=int), s)), shape=(1, M))\n",
    "s = spike_times//dt\n",
    "B_spike = []\n",
    "# spike bin indices as np.intp (np.int0, its old alias, was removed in NumPy 2.0)\n",
    "for unit_id in range(num_unit):\n",
    "    B_spike.append(sparse.coo_matrix((np.ones(len(s[unit_id])), (np.zeros(len(s[unit_id]), dtype=int), np.intp(s[unit_id]))), shape=(1, M)))\n",
    "\n",
    "# histogram error bars: num spikes\n",
    "s = np.zeros((num_unit, 2))\n",
    "for unit_id in range(num_unit):\n",
    "    # print(\"unit: %d\"%unit_id)\n",
    "    a = np.zeros(len(tau_id_range))\n",
    "    for tau_id in tau_id_range:\n",
    "        a[tau_id] = np.sum(B_spike[unit_id].col >= tau_id)\n",
    "\n",
    "    s[unit_id] = [np.mean(a), np.std(a)]\n",
    "\n",
    "# raw strings keep the LaTeX backslashes literal ('\\p' is not a valid escape sequence)\n",
    "key_symbol = {'pair': r'$(\\theta,\\phi)$', 'orientation': r'$\\theta$', 'phase': r'$\\phi$'}\n",
    "\n",
    "# 2D tuning\n",
    "avg_firing_rate_pair = np.array([sts.zscore(stim_hist['pair'][unit_id]).reshape((len(tau_id_range), num_stim['orientation'], num_stim['phase'])) for unit_id in range(num_unit)])\n",
    "\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, './Persistent_Homology')\n",
    "\n",
    "from gratings import grating_model\n",
    "from plotting import plot_data, plot_mean_against_index, show_feature\n",
    "from persistence import persistence\n",
    "from decoding import cohomological_parameterization, remove_feature\n",
    "from noisereduction import *\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# relabelling\n",
    "param = np.arange(400).reshape(20,20)\n",
    "phase_ref = 10\n",
    "x = param[:phase_ref]\n",
    "reparam = np.vstack((x, x[:,::-1]))\n",
    "stim_val_reparam = {}\n",
    "stim_val_reparam['phase'] = stim_val['phase'][:phase_ref]\n",
    "stim_val_reparam['orientation'] = np.concatenate((stim_val['orientation'], stim_val['orientation']+180))\n",
    "\n",
    "\n",
    "opt_time = np.load('optimal_time.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stim_data.pkl\n",
      "dict_keys(['stim_val', 'trial_stim_id', 'key_list', 'num_trial', 'trial_pair_id', 'pair_val', 'pair_trial_id', 'stim_id_trial', 'num_stim'])\n",
      "spike_data.pkl\n",
      "dict_keys(['spike_count_rate', 'avg_firing_rate', 'sem_firing_rate', 'firing_rate', 'stim_num_trial', 'C_r_fphi_theta', 'theta_hist', 'phase_hist', 'pair_hist'])\n",
      "corr_data.pkl\n",
      "dict_keys(['corr_stim_unit', 'optimal_avg_firing_rate', 'stim_hist', 'stim_hist_caution'])\n"
     ]
    }
   ],
   "source": [
    "# load data\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import scipy.stats as sts\n",
    "import pickle\n",
    "import ipywidgets as widgets\n",
    "from mpl_toolkits import mplot3d\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from scipy import sparse\n",
    "from matplotlib import cm\n",
    "\n",
    "# color maps (pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9)\n",
    "cmap_hot = plt.get_cmap('hot')\n",
    "cmap_viridis = plt.get_cmap('viridis')\n",
    "cmap_jet = plt.get_cmap('jet')\n",
    "\n",
    "# load and process stimuli\n",
    "execfile('Stimulus.py')\n",
    "data_folder = \"Data/\"\n",
    "stim_file = \"Stiminfo_PVCre_2021_0012_s06_e14.csv\"\n",
    "stim = pd.read_csv(data_folder+stim_file)\n",
    "num_trial = len(stim)\n",
    "\n",
    "# load and process spikes\n",
    "spike_times_file = \"Spiketimes_PVCre_2021_0012_s06_e14.npy\"\n",
    "spike_times = np.load(data_folder+spike_times_file, allow_pickle=True)\n",
    "active = [len(spike_times[i]) > 0 for i in range(len(spike_times))]\n",
    "spike_times = spike_times[np.where(active)]\n",
    "num_unit = len(spike_times)\n",
    "\n",
    "\n",
    "# sort by firing rate\n",
    "num_spike = list(map(len, spike_times))\n",
    "spike_times = spike_times[np.argsort(num_spike)[::-1]]\n",
    "execfile('load.py')\n",
    "\n",
    "# reverse correlation time offset range\n",
    "max_delay = 300 # dt\n",
    "tau_id_range = np.arange(max_delay)\n",
    "\n",
    "# experiment duration\n",
    "latest_spike_time = max([np.max(s) for s in spike_times if len(s)])\n",
    "latest_stim_offtime = list(stim['stim_offtime'])[-1]\n",
    "experiment_dur = max([latest_spike_time, latest_stim_offtime])\n",
    "\n",
    "dt = 0.001 # 1 ms\n",
    "exp_time = np.arange(0, experiment_dur, dt)\n",
    "M = len(exp_time)\n",
    "\n",
    "# binary spike and stimulus trains\n",
    "B_stim = {}\n",
    "for key in key_list:\n",
    "    B_stim[key] = []\n",
    "    for stim_id, trials in enumerate(stim_id_trial[key]):\n",
    "        B_stim[key].append([])\n",
    "        s = []\n",
    "        for trial_id in trials:\n",
    "            t_on, t_off = stim['stim_ontime'][trial_id], stim['stim_offtime'][trial_id]\n",
    "            s += list(np.arange(int(t_on//dt), int(t_off//dt)))\n",
    "\n",
    "        B_stim[key][stim_id] = sparse.coo_matrix((np.ones(len(s)), (np.zeros(len(s), dtype=int), s)), shape=(1, M))\n",
    "s = spike_times//dt\n",
    "B_spike = []\n",
    "# spike bin indices as np.intp (np.int0, its old alias, was removed in NumPy 2.0)\n",
    "for unit_id in range(num_unit):\n",
    "    B_spike.append(sparse.coo_matrix((np.ones(len(s[unit_id])), (np.zeros(len(s[unit_id]), dtype=int), np.intp(s[unit_id]))), shape=(1, M)))\n",
    "\n",
    "# raw strings keep the LaTeX backslashes literal ('\\p' is not a valid escape sequence)\n",
    "key_symbol = {'pair': r'$(\\theta,\\phi)$', 'orientation': r'$\\theta$', 'phase': r'$\\phi$'}\n",
    "\n",
    "# 2D tuning\n",
    "avg_firing_rate_pair = np.array([sts.zscore(stim_hist['pair'][unit_id]).reshape((len(tau_id_range), num_stim['orientation'], num_stim['phase'])) for unit_id in range(num_unit)])\n",
    "\n",
    "sorted_spike_num = np.sort(num_spike)[::-1]\n",
    "cutoff_num_spike = 1000\n",
    "num_unit = np.sum(sorted_spike_num > cutoff_num_spike)\n",
    "for key in key_list:\n",
    "    stim_hist[key] = stim_hist[key][:num_unit]\n",
    "\n",
    "# per-unit optimal delay index, needed by the PCA cell below;\n",
    "# loaded here so this cell does not depend on state leaked from earlier cells\n",
    "opt_time = np.load('optimal_time.npy')\n",
    "\n",
    "from tqdm import trange\n",
    "import sys\n",
    "sys.path.insert(0, './Persistent_Homology')\n",
    "\n",
    "import gratings\n",
    "import decorator\n",
    "from noisereduction import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate each unit at its own optimal delay (`opt_time[unit_id]`), chosen independently per unit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Z = sts.zscore([stim_hist['pair'][unit_id, opt_time[unit_id], :].T for unit_id in range(num_unit)], axis=0)\n",
    "# taking zscore does not make sense since we are dealing with probability distributions as activity not firing rate\n",
    "# Z = Z[:, np.logical_not(np.isnan(Z[0]))]\n",
    "# n_components = len(Z[0])\n",
    "from sklearn.decomposition import PCA\n",
    "from scipy.stats import zscore\n",
    "n_components = 20\n",
    "pca = PCA(n_components)\n",
    "# one row per unit, transposed so that each column is a unit\n",
    "data = pd.DataFrame([stim_hist['pair'][unit_id, opt_time[unit_id], :] for unit_id in range(num_unit)]).T\n",
    "# X = pca.fit_transform(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PCA explained variance:\n",
      "[0.16646987 0.10646209 0.07980927 0.07637544 0.07305655 0.06312245\n",
      " 0.0520563 0.05141274 0.0436141 0.04114367 0.0392689 0.02982529\n",
      " 0.02595947 0.02514927 0.02247076 0.01964191 0.01509293 0.01312476\n",
      " 0.01038999 0.0080545 ]\n"
     ]
    }
   ],
   "source": [
    "ncomp = 20\n",
    "# NOTE: reassigns `data`; re-run the previous cell before re-running this one,\n",
    "# otherwise the already-reduced data is reduced again\n",
    "data = PCA_reduction(data, ncomp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13e6d26adfaf4482833713f11abb7724",
       "version_major": 2,
       "version_minor": 0
      },
      "image/png": "",
      "text/html": [
       "\n",
       "
\n", " | decoding | \n", "phase | \n", "orientation | \n", "
---|---|---|---|
0 | \n", "0.437717 | \n", "0.0 | \n", "0.0 | \n", "
1 | \n", "0.449213 | \n", "18.0 | \n", "0.0 | \n", "
2 | \n", "0.429066 | \n", "36.0 | \n", "0.0 | \n", "
3 | \n", "0.424594 | \n", "54.0 | \n", "0.0 | \n", "
4 | \n", "0.816654 | \n", "72.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
395 | \n", "0.438846 | \n", "270.0 | \n", "171.0 | \n", "
396 | \n", "0.441466 | \n", "288.0 | \n", "171.0 | \n", "
397 | \n", "0.462445 | \n", "306.0 | \n", "171.0 | \n", "
398 | \n", "0.672323 | \n", "324.0 | \n", "171.0 | \n", "
399 | \n", "0.677862 | \n", "342.0 | \n", "171.0 | \n", "
400 rows × 3 columns
\n", "\n", " | decoding | \n", "phase | \n", "orientation | \n", "
---|---|---|---|
0 | \n", "0.217130 | \n", "0.0 | \n", "0.0 | \n", "
1 | \n", "0.329715 | \n", "18.0 | \n", "0.0 | \n", "
2 | \n", "0.244882 | \n", "36.0 | \n", "0.0 | \n", "
3 | \n", "0.260258 | \n", "54.0 | \n", "0.0 | \n", "
4 | \n", "0.781878 | \n", "72.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
395 | \n", "0.204446 | \n", "270.0 | \n", "171.0 | \n", "
396 | \n", "0.207642 | \n", "288.0 | \n", "171.0 | \n", "
397 | \n", "0.181627 | \n", "306.0 | \n", "171.0 | \n", "
398 | \n", "0.552974 | \n", "324.0 | \n", "171.0 | \n", "
399 | \n", "0.558591 | \n", "342.0 | \n", "171.0 | \n", "
400 rows × 3 columns
\n", "\n", " | decoding | \n", "
---|---|
261 | \n", "0.932910 | \n", "
396 | \n", "0.930706 | \n", "
322 | \n", "0.932025 | \n", "
253 | \n", "0.930178 | \n", "
113 | \n", "0.931438 | \n", "
... | \n", "... | \n", "
310 | \n", "0.930397 | \n", "
327 | \n", "0.932556 | \n", "
207 | \n", "0.929700 | \n", "
160 | \n", "0.929805 | \n", "
233 | \n", "0.931270 | \n", "
200 rows × 1 columns
\n", "