Lev
2 years ago
commit
7ee1b4320f
12 changed files with 4628 additions and 0 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -0,0 +1,188 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
""" |
||||||
|
Use cohomology to decode datasets with circular parameters |
||||||
|
|
||||||
|
Persistent homology from arxiv:1908.02518 |
||||||
|
Homological decoding from DOI:10.1007/s00454-011-9344-x and arxiv:1711.07205 |
||||||
|
""" |
||||||
|
import math |
||||||
|
import numpy as np |
||||||
|
from scipy.optimize import least_squares |
||||||
|
import pandas as pd |
||||||
|
|
||||||
|
from tqdm import trange |
||||||
|
|
||||||
|
import ripser |
||||||
|
|
||||||
|
from persistence import persistence |
||||||
|
|
||||||
|
EPSILON = 0.0000000000001 |
||||||
|
|
||||||
|
|
||||||
|
def shortest_cycle(graph, node2, node1): |
||||||
|
""" |
||||||
|
Returns the shortest cycle going through an edge |
||||||
|
|
||||||
|
Used for computing weights in decode |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
graph: ndarray (n_nodes, n_nodes) |
||||||
|
A matrix containing the weights of the edges in the graph |
||||||
|
node1: int |
||||||
|
The index of the first node of the edge |
||||||
|
node2: int |
||||||
|
The index of the second node of the edge |
||||||
|
|
||||||
|
Returns |
||||||
|
------- |
||||||
|
cycle: list of ints |
||||||
|
A list of indices representing the nodes of the cycle in order |
||||||
|
""" |
||||||
|
N = graph.shape[0] |
||||||
|
distances = np.inf * np.ones(N) |
||||||
|
distances[node2] = 0 |
||||||
|
prev_nodes = np.zeros(N) |
||||||
|
prev_nodes[:] = np.nan |
||||||
|
prev_nodes[node2] = node1 |
||||||
|
while (math.isnan(prev_nodes[node1])): |
||||||
|
distances_buffer = distances |
||||||
|
for j in range(N): |
||||||
|
possible_path_lengths = distances_buffer + graph[:,j] |
||||||
|
if (np.min(possible_path_lengths) < distances[j]): |
||||||
|
prev_nodes[j] = np.argmin(possible_path_lengths) |
||||||
|
distances[j] = np.min(possible_path_lengths) |
||||||
|
prev_nodes = prev_nodes.astype(int) |
||||||
|
cycle = [node1] |
||||||
|
while (cycle[0] != node2): |
||||||
|
cycle.insert(0,prev_nodes[cycle[0]]) |
||||||
|
cycle.insert(0,node1) |
||||||
|
return cycle |
||||||
|
|
||||||
|
def cohomological_parameterization(X ,cocycle_number=1, coeff=2,weighted=False): |
||||||
|
""" |
||||||
|
Compute an angular parametrization on the data set corresponding to a given |
||||||
|
1-cycle |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: ndarray(n_datapoints, n_features): |
||||||
|
Array containing the data |
||||||
|
cocycle_number: int, optional, default 1 |
||||||
|
An integer specifying the 1-cycle used |
||||||
|
The n-th most stable 1-cycle is used, where n = cocycle_number |
||||||
|
coeff: int prime, optional, default 1 |
||||||
|
The coefficient basis in which we compute the cohomology |
||||||
|
weighted: bool, optional, default False |
||||||
|
When true use a weighted graph for smoother parameterization |
||||||
|
as proposed in arxiv:1711.07205 |
||||||
|
|
||||||
|
Returns |
||||||
|
------- |
||||||
|
decoding: ndarray(n_datapoints) |
||||||
|
The parameterization of the dataset consisting of a number between |
||||||
|
0 and 1 for each datapoint, to be interpreted modulo 1 |
||||||
|
""" |
||||||
|
# Get the cocycle |
||||||
|
result = ripser.ripser(X, maxdim=1, coeff=coeff, do_cocycles=True) |
||||||
|
diagrams = result['dgms'] |
||||||
|
cocycles = result['cocycles'] |
||||||
|
D = result['dperm2all'] |
||||||
|
dgm1 = diagrams[1] |
||||||
|
idx = np.argsort(dgm1[:, 1] - dgm1[:, 0])[-cocycle_number] |
||||||
|
cocycle = cocycles[1][idx] |
||||||
|
persistence(X, homdim=1, coeff=coeff, show_largest_homology=0, |
||||||
|
Nsubsamples=0, save_path=None, cycle=idx) |
||||||
|
thresh = dgm1[idx, 1]-EPSILON |
||||||
|
|
||||||
|
# Compute connectivity |
||||||
|
N = X.shape[0] |
||||||
|
connectivity = np.zeros([N,N]) |
||||||
|
for i in range(N): |
||||||
|
for j in range(i): |
||||||
|
if D[i, j] <= thresh: |
||||||
|
connectivity[i,j] = 1 |
||||||
|
cocycle_array = np.zeros([N,N]) |
||||||
|
|
||||||
|
# Lift cocycle |
||||||
|
for i in range(cocycle.shape[0]): |
||||||
|
cocycle_array[cocycle[i,0],cocycle[i,1]] = ( |
||||||
|
((cocycle[i,2] + coeff/2) % coeff) - coeff/2 |
||||||
|
) |
||||||
|
|
||||||
|
# Weights |
||||||
|
if (weighted): |
||||||
|
def real_cocycle(x): |
||||||
|
real_cocycle =( |
||||||
|
connectivity * (cocycle_array + np.subtract.outer(x, x)) |
||||||
|
) |
||||||
|
return np.ravel(real_cocycle) |
||||||
|
|
||||||
|
# Compute graph |
||||||
|
x0 = np.zeros(N) |
||||||
|
res = least_squares(real_cocycle, x0) |
||||||
|
real_cocyle_array = res.fun |
||||||
|
real_cocyle_array = real_cocyle_array.reshape(N,N) |
||||||
|
real_cocyle_array = real_cocyle_array - np.transpose(real_cocyle_array) |
||||||
|
graph = np.array(real_cocyle_array>0).astype(float) |
||||||
|
graph[graph==0] = np.inf |
||||||
|
graph = (D + EPSILON) * graph # Add epsilon to avoid NaNs |
||||||
|
|
||||||
|
# Compute weights |
||||||
|
cycle_counts = np.zeros([N,N]) |
||||||
|
iterator = trange(0, N, position=0, leave=True) |
||||||
|
iterator.set_description("Computing weights for decoding") |
||||||
|
for i in iterator: |
||||||
|
for j in range(N): |
||||||
|
if (graph[i,j] != np.inf): |
||||||
|
cycle = shortest_cycle(graph, j, i) |
||||||
|
for k in range(len(cycle)-1): |
||||||
|
cycle_counts[cycle[k], cycle[k+1]] += 1 |
||||||
|
|
||||||
|
weights = cycle_counts / (D + EPSILON)**2 |
||||||
|
weights = np.sqrt(weights) |
||||||
|
else: |
||||||
|
weights = np.outer(np.ones(N),np.ones(N)) |
||||||
|
|
||||||
|
def real_cocycle(x): |
||||||
|
real_cocycle =( |
||||||
|
weights * connectivity * (cocycle_array + np.subtract.outer(x, x)) |
||||||
|
) |
||||||
|
return np.ravel(real_cocycle) |
||||||
|
|
||||||
|
# Smooth cocycle |
||||||
|
print("Decoding...", end=" ") |
||||||
|
x0 = np.zeros(N) |
||||||
|
res = least_squares(real_cocycle, x0) |
||||||
|
decoding = res.x |
||||||
|
decoding = np.mod(decoding, 1) |
||||||
|
print("done") |
||||||
|
|
||||||
|
decoding = pd.DataFrame(decoding, columns=["decoding"]) |
||||||
|
decoding = decoding.set_index(X.index) |
||||||
|
return decoding |
||||||
|
|
||||||
|
|
||||||
|
def remove_feature(X, decoding, shift=0, cut_amplitude=1.0): |
||||||
|
""" |
||||||
|
Removes a decoded feature from a dataset by making a cut at a fixed value |
||||||
|
of the decoding |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Array containing the data |
||||||
|
decoding : dataframe(n_datapoints) |
||||||
|
The decoded feature, assumed to be angular with periodicity 1 |
||||||
|
shift : float between 0 and 1, optional, default 0 |
||||||
|
The location of the cut |
||||||
|
cut_amplitude : float, optional, default 1 |
||||||
|
Amplitude of the cut |
||||||
|
""" |
||||||
|
cuts = np.zeros(X.shape) |
||||||
|
decoding = decoding.to_numpy()[:,0] |
||||||
|
for i in range(X.shape[1]): |
||||||
|
effective_amplitude = cut_amplitude * (np.max(X[i]) - np.min(X[i])) |
||||||
|
cuts[:,i] = effective_amplitude * ((decoding - shift) % 1) |
||||||
|
reduced_data = X + cuts |
||||||
|
return reduced_data |
@ -0,0 +1,44 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
""" |
||||||
|
Some decorators useful for data analysis functions |
||||||
|
""" |
||||||
|
import numpy as np |
||||||
|
|
||||||
|
def multi_input(f): |
||||||
|
"""Allow a function to also be applied to each element in a dictionary""" |
||||||
|
def wrapper(data, *args, **kwargs): |
||||||
|
if type(data) is dict: |
||||||
|
output_data = {} |
||||||
|
for name in data: |
||||||
|
output_data[name] = f(data[name], *args, **kwargs) |
||||||
|
if all(x is None for x in output_data.values()): |
||||||
|
return |
||||||
|
else: |
||||||
|
return output_data |
||||||
|
else: |
||||||
|
return f(data, *args, **kwargs) |
||||||
|
return wrapper |
||||||
|
|
||||||
|
def av_output(f): |
||||||
|
"""Allow running a function multiple times returning the average output""" |
||||||
|
def wrapper(average=1, *args, **kwargs): |
||||||
|
data_av = f(*args, **kwargs) |
||||||
|
try: |
||||||
|
if not isinstance(data_av, np.ndarray): |
||||||
|
data_av = np.array(data_av) |
||||||
|
except: |
||||||
|
pass |
||||||
|
for i in range(average - 1): |
||||||
|
delta = f(*args, **kwargs) |
||||||
|
try: |
||||||
|
if not isinstance(delta, np.ndarray): |
||||||
|
delta = np.array(delta) |
||||||
|
except: |
||||||
|
pass |
||||||
|
data_av += delta |
||||||
|
if not isinstance(data_av, tuple): |
||||||
|
data_av /= average |
||||||
|
else: |
||||||
|
data_av = [d / average for d in data_av] |
||||||
|
return data_av |
||||||
|
return wrapper |
@ -0,0 +1,61 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
import numpy as np |
||||||
|
|
||||||
|
from tqdm import trange |
||||||
|
|
||||||
|
import matplotlib.pyplot as plt |
||||||
|
|
||||||
|
from decorators import multi_input |
||||||
|
|
||||||
|
|
||||||
|
@multi_input |
||||||
|
def estimate_dimension(X, max_size, test_size = 30, Nsteps = 20, fraction = 0.5): |
||||||
|
""" |
||||||
|
Plots an estimation of the dimension of a dataset at different scales |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
max_size : float |
||||||
|
The upper bound for the scale |
||||||
|
test_size : int, optional, default 30 |
||||||
|
The number of datapoints used to estimate the density |
||||||
|
Nsteps : int, optional, default 20 |
||||||
|
The number of different scales at which the density is estimated |
||||||
|
fraction : float between 0 and 1, optional, default 0.5 |
||||||
|
Difference in radius between the large sphere and smaller sphere used to compute density |
||||||
|
|
||||||
|
Returns |
||||||
|
------- |
||||||
|
average : ndarray(Nsteps) |
||||||
|
The dimension at each scale |
||||||
|
|
||||||
|
""" |
||||||
|
average = np.zeros(Nsteps) |
||||||
|
S = X.iloc[np.random.choice(X.shape[0], test_size, replace=False)] |
||||||
|
|
||||||
|
iterator = trange(0, Nsteps, position=0, leave=True) |
||||||
|
iterator.set_description("Estimating dimension") |
||||||
|
for n in iterator: |
||||||
|
size = max_size*n/Nsteps |
||||||
|
count_small = np.zeros(X.shape[0]) |
||||||
|
count_large = np.zeros(X.shape[0]) |
||||||
|
dimension = np.zeros(S.shape[0]) |
||||||
|
for i in range(0,S.shape[0]): |
||||||
|
for j in range(0,X.shape[0]): |
||||||
|
distance = np.sqrt(np.square(S.iloc[i] - X.iloc[j]).sum()) |
||||||
|
if (distance < size/fraction): |
||||||
|
count_large[i] += 1 |
||||||
|
if (distance < size): |
||||||
|
count_small[i] += 1 |
||||||
|
if (count_large[i] != 0): |
||||||
|
dimension[i] = np.log(count_small[i]/count_large[i])/np.log(fraction) |
||||||
|
else: |
||||||
|
dimension[i] = 0 |
||||||
|
average[n] = np.mean(dimension) |
||||||
|
plt.plot(range(0, Nsteps), average) |
||||||
|
plt.xlabel("Scale") |
||||||
|
plt.ylabel("Dimension") |
||||||
|
plt.show() |
||||||
|
return average |
@ -0,0 +1,43 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import pandas as pd |
||||||
|
|
||||||
|
import sys |
||||||
|
sys.path.insert(0, './Modules') |
||||||
|
|
||||||
|
from gratings import grating_model |
||||||
|
from plotting import plot_data, plot_mean_against_index, show_feature |
||||||
|
from persistence import persistence |
||||||
|
from decoding import cohomological_parameterization, remove_feature |
||||||
|
from noisereduction import PCA_reduction, z_cutoff |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Generate data |
||||||
|
data = grating_model(Nn=8, Np=(18,1,18,1), deltaT=55, random_neurons=True) |
||||||
|
|
||||||
|
## Apply noise reduction |
||||||
|
# data = PCA_reduction(data, 5) |
||||||
|
# data = z_cutoff(data,2) |
||||||
|
|
||||||
|
## Analyze shape |
||||||
|
persistence(data,homdim=2,coeff=2) |
||||||
|
persistence(data,homdim=2,coeff=3) |
||||||
|
|
||||||
|
## Decode first parameter |
||||||
|
decoding1 = cohomological_parameterization(data, coeff=23) |
||||||
|
show_feature(decoding1) |
||||||
|
plot_mean_against_index(data,decoding1,"orientation") |
||||||
|
plot_mean_against_index(data,decoding1,"phase") |
||||||
|
# plot_data(data,transformation="PCA", labels=decoding1, |
||||||
|
# colors=["Twilight","Viridis","Twilight","Viridis","Twilight"]) |
||||||
|
|
||||||
|
## Decode second parameter |
||||||
|
# reduced_data = remove_feature(data, decoding1, cut_amplitude=0.5) |
||||||
|
# decoding2 = cohomological_parameterization(reduced_data, coeff=23) |
||||||
|
# show_feature(decoding2) |
||||||
|
# plot_mean_against_index(data,decoding2,"orientation") |
||||||
|
# plot_mean_against_index(data,decoding2,"phase") |
||||||
|
# plot_data(data,transformation="PCA", labels=decoding2, |
||||||
|
# colors=["Twilight","Viridis","Twilight","Viridis","Twilight"]) |
@ -0,0 +1,225 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
""" |
||||||
|
Simulation of simple cells responding to grating images |
||||||
|
""" |
||||||
|
import numpy as np |
||||||
|
from numpy.random import poisson |
||||||
|
import pandas as pd |
||||||
|
from scipy.integrate import dblquad |
||||||
|
from itertools import product |
||||||
|
|
||||||
|
from collections import namedtuple |
||||||
|
from tqdm import trange |
||||||
|
from numba import njit |
||||||
|
|
||||||
|
import matplotlib.pyplot as plt |
||||||
|
|
||||||
|
from decorators import av_output |
||||||
|
|
||||||
|
GRATING_PARS = ["orientation", "frequency", "phase", "contrast"] |
||||||
|
Grating = namedtuple("Grating", GRATING_PARS) |
||||||
|
|
||||||
|
@njit |
||||||
|
def gabor_function(x, y, grating, sigma=None): |
||||||
|
"""Returns the value of a grating function at given x and y coordinates""" |
||||||
|
theta, f, phi, C = grating |
||||||
|
if sigma is None: |
||||||
|
sigma = 1.5/f |
||||||
|
|
||||||
|
return C*np.exp(-1/(2*sigma**2) * (x**2 + y**2))*np.cos(2*np.pi*f*(x*np.cos(theta)+ y*np.sin(theta)) + phi) |
||||||
|
|
||||||
|
def grating_function(x, y, grating): |
||||||
|
"""Returns the value of a grating function at given x and y coordinates""" |
||||||
|
smallest_distance = 0.1 |
||||||
|
theta, f, phi, C = grating |
||||||
|
return C*np.exp(-1/2*f**2*smallest_distance**2)*np.cos(2*np.pi*f*(x*np.cos(theta)+ y*np.sin(theta)) + phi) |
||||||
|
|
||||||
|
def grating_image(grating, gabor=False, rf_sigma=None, center=(0,0), N = 50, plot=True): |
||||||
|
""" |
||||||
|
Make an image of a grating |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
grating: Grating |
||||||
|
A tuple containing the orientation, frequency, phase and contrast |
||||||
|
N: int, optional, default 50 |
||||||
|
The number of pixels in each direction |
||||||
|
plot: bool, optional, default True |
||||||
|
When true plot the image |
||||||
|
|
||||||
|
Returns |
||||||
|
------- |
||||||
|
image: ndarray(N, N) |
||||||
|
An array of floats corresponding to the pixel values of the image |
||||||
|
""" |
||||||
|
|
||||||
|
if gabor: |
||||||
|
func = lambda x,y :gabor_function(x-center[0], y-center[1], grating, sigma=rf_sigma) |
||||||
|
else: |
||||||
|
func = lambda x,y :grating_function(X[i],X[j], grating) |
||||||
|
|
||||||
|
X = np.linspace(-1, 1, N) |
||||||
|
image = np.zeros([N,N]) |
||||||
|
for i in range(0,N): |
||||||
|
for j in range(0,N): |
||||||
|
image[i,j] = func(X[i],X[j]) |
||||||
|
|
||||||
|
if (plot): |
||||||
|
plt.imshow(image,"gray", vmin = -1, vmax = 1) |
||||||
|
plt.show() |
||||||
|
return image |
||||||
|
|
||||||
|
def angular_mean(X, period=2*np.pi, axis=None): |
||||||
|
""" |
||||||
|
Average over an angular variable |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: ndarray |
||||||
|
Array of angles to average over |
||||||
|
period: float, optional, default 2*pi |
||||||
|
The period of the angles |
||||||
|
axis: int, optional, default None |
||||||
|
The axis of X to average over |
||||||
|
""" |
||||||
|
ang_mean = ( |
||||||
|
(period/(2*np.pi)) |
||||||
|
* np.angle(np.mean(np.exp((2*np.pi/period) * X*1j),axis=axis)) |
||||||
|
% period |
||||||
|
) |
||||||
|
return ang_mean |
||||||
|
|
||||||
|
@njit |
||||||
|
def sigmoid(x): |
||||||
|
"""Sigmoid function""" |
||||||
|
return 1/(1+np.exp(1*(1/2-x))) |
||||||
|
|
||||||
|
def response(grating1, grating2, rf_sigma=None, center=(0,0)): |
||||||
|
""" |
||||||
|
Neural response of a simple cell to a grating image |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
grating1: Grating |
||||||
|
Grating defining the simple cell |
||||||
|
grating2: Grating |
||||||
|
Grating corresponding to the image shown |
||||||
|
center: (float,float), optional, default (0,0) |
||||||
|
The focal point |
||||||
|
""" |
||||||
|
fun1 = lambda s,t : gabor_function(s - center[0], t - center[1], grating1, sigma=rf_sigma) |
||||||
|
fun2 = lambda s,t : grating_function(s ,t, grating2) |
||||||
|
product = lambda s,t : fun1(s,t) * fun2(s,t) |
||||||
|
integral = dblquad(product, -1, 1, -1, 1, epsabs=0.01)[0] |
||||||
|
response = sigmoid(integral) |
||||||
|
return response |
||||||
|
|
||||||
|
def get_locations(N): |
||||||
|
""" |
||||||
|
Return uniformly distributed locations on the grating parameter space |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
N: (int,int,int,int) |
||||||
|
The number of different orientations, frequencies, phases and contrasts |
||||||
|
respectively |
||||||
|
""" |
||||||
|
N_or, N_fr, N_ph, N_co = N |
||||||
|
|
||||||
|
if (N_or==1): |
||||||
|
orientation = [0] |
||||||
|
else: |
||||||
|
if (N_ph==1): |
||||||
|
orientation = np.linspace(0.0, 2 * (1-1/N_or) * np.pi, N_or) |
||||||
|
else: |
||||||
|
orientation = np.linspace(0.0, (1-1/N_or) * np.pi, N_or) |
||||||
|
|
||||||
|
if (N_fr==1): |
||||||
|
frequency=[1.5] |
||||||
|
else: |
||||||
|
frequency = np.linspace(0.0, 9, N_fr) |
||||||
|
|
||||||
|
if (N_ph==1): |
||||||
|
phase = [np.pi/4] |
||||||
|
else: |
||||||
|
phase = np.linspace(0.0, 2 * (1-1/N_ph) * np.pi, N_ph) |
||||||
|
|
||||||
|
if (N_co==1): |
||||||
|
contrast = [1] |
||||||
|
else: |
||||||
|
contrast = np.linspace(0.0, 1.0, N_co) |
||||||
|
|
||||||
|
locations = list(product(orientation, frequency, phase, contrast)) |
||||||
|
return locations |
||||||
|
|
||||||
|
@av_output |
||||||
|
def grating_model(Nn, Np, rf_sigma=None, |
||||||
|
deltaT=None, random_focal_points=False, plot_stimuli=False): |
||||||
|
""" |
||||||
|
Simulate the firing of simple cells responding to images of gratings |
||||||
|
|
||||||
|
Simple cells and stimuli are uniformly distributed along the parameter space |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
Nn: int |
||||||
|
The number of different orientations, frequencies, phases and contrasts |
||||||
|
for the neurons |
||||||
|
Directions where the stimuli do not vary will not be included |
||||||
|
Np: (int,int,int,int) |
||||||
|
The number of different orientations, frequencies, phases and contrasts |
||||||
|
respectively for the stimuli |
||||||
|
rf_sigma: float, optional, default 5 |
||||||
|
The width of the simple cell receptive fields |
||||||
|
deltaT: float, optional, default None |
||||||
|
The time period spikes are sampled over for each stimulus |
||||||
|
When None return the exact firing rates instead |
||||||
|
random_focal_points: bool, optional, default False |
||||||
|
If true randomize the focal point for each stimulus |
||||||
|
plot_stimuli: bool, optional, default False |
||||||
|
If true plot an image of each stimulus |
||||||
|
average: int, optional, default=1 |
||||||
|
The number of times the simulation is repeated and averaged over |
||||||
|
|
||||||
|
Returns |
||||||
|
------- |
||||||
|
data: dataframe(n_datapoints, n_neurons) |
||||||
|
The simulated firing rate data |
||||||
|
""" |
||||||
|
|
||||||
|
Points = get_locations(Np) |
||||||
|
Neurons = get_locations(([Nn,1][Np[0]==1], |
||||||
|
[Nn,1][Np[1]==1], |
||||||
|
[Nn,1][Np[2]==1], |
||||||
|
[Nn,1][Np[3]==1])) |
||||||
|
|
||||||
|
# Set focal points |
||||||
|
if random_focal_points: |
||||||
|
focal_points = np.random.random([len(Points),2])*2-1 |
||||||
|
else: |
||||||
|
focal_points = np.zeros([len(Points),2]) |
||||||
|
|
||||||
|
# Compute firing rates |
||||||
|
rates = np.zeros([len(Points), len(Neurons)]) |
||||||
|
iterator = trange(0, len(Points), position=0, leave=True) |
||||||
|
iterator.set_description("Simulating data points") |
||||||
|
for i in iterator: |
||||||
|
if (i % 1 == 0): |
||||||
|
if plot_stimuli: |
||||||
|
grating_image(Points[i]) |
||||||
|
for j in range(0, len(Neurons)): |
||||||
|
rates[i,j] = response(Points[i], Neurons[j], rf_sigma=rf_sigma, center=focal_points[i]) |
||||||
|
|
||||||
|
# Add noise |
||||||
|
if deltaT is None: |
||||||
|
data = rates |
||||||
|
else: |
||||||
|
data = poisson(rates * deltaT) |
||||||
|
|
||||||
|
data = pd.DataFrame(data) |
||||||
|
data = pd.merge(pd.DataFrame(Points, columns = GRATING_PARS), |
||||||
|
data,left_index=True,right_index=True) |
||||||
|
data = data.set_index(GRATING_PARS) |
||||||
|
|
||||||
|
return data, focal_points |
||||||
|
|
@ -0,0 +1,11 @@ |
|||||||
|
import pickle |
||||||
|
|
||||||
|
|
||||||
|
def pkl_load(filename): |
||||||
|
with open(filename, 'rb') as f: |
||||||
|
return pickle.load(f) |
||||||
|
|
||||||
|
|
||||||
|
def pkl_save(filename, obj): |
||||||
|
with open(filename, 'wb') as f: |
||||||
|
pickle.dump(obj, f) |
@ -0,0 +1,179 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
""" |
||||||
|
A collection of noise reduction algorithms |
||||||
|
""" |
||||||
|
import numpy as np |
||||||
|
import scipy |
||||||
|
import pandas as pd |
||||||
|
|
||||||
|
from matplotlib import pyplot |
||||||
|
from mpl_toolkits.mplot3d import Axes3D |
||||||
|
|
||||||
|
from sklearn.decomposition import PCA |
||||||
|
|
||||||
|
from tqdm import trange |
||||||
|
from numba import njit, prange |
||||||
|
|
||||||
|
from persistence import persistence |
||||||
|
from decorators import multi_input |
||||||
|
|
||||||
|
|
||||||
|
@njit(parallel=True) |
||||||
|
def compute_gradient(S, X, sigma, omega): |
||||||
|
"""Compute gradient of F as in arxiv:0910.5947""" |
||||||
|
gradF = np.zeros(S.shape) |
||||||
|
d = X.shape[1] |
||||||
|
N = X.shape[0] |
||||||
|
M = S.shape[0] |
||||||
|
for j in range(0,M): |
||||||
|
normsSX = np.square(S[j] - X).sum(axis=1) |
||||||
|
normsSS = np.square(S[j] - S).sum(axis=1) |
||||||
|
expsSX = np.exp(-1/(2*sigma**2)*normsSX) |
||||||
|
expsSS = np.exp(-1/(2*sigma**2)*normsSS) |
||||||
|
SX, SS = np.zeros(d), np.zeros(d) |
||||||
|
for k in range(0,d): |
||||||
|
SX[k] = -1/(N*sigma**2) * np.sum((S[j] - X)[:,k] * expsSX) |
||||||
|
SS[k] = omega/(M*sigma**2) * np.sum((S[j] - S)[:,k] * expsSS) |
||||||
|
gradF[j] = SX + SS |
||||||
|
return gradF |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def top_noise_reduction(X, n=100, omega=0.2, fraction=0.1, plot=False): |
||||||
|
""" |
||||||
|
Topological denoising algorithm as in arxiv:0910.5947 |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
n: int, optional, default 100 |
||||||
|
Number of iterations |
||||||
|
omega: float, optional, default 0.2 |
||||||
|
Strength of the repulsive force between datapoints |
||||||
|
fraction: float between 0 and 1, optional, default 0.1 |
||||||
|
The fraction of datapoints from which the denoised dataset is |
||||||
|
constructed |
||||||
|
plot: bool, optional, default False |
||||||
|
When true plot the dataset and homology each iteration |
||||||
|
""" |
||||||
|
N = X.shape[0] |
||||||
|
S = X.iloc[np.random.choice(N, round(fraction*N), replace=False)] |
||||||
|
sigma = X.stack().std() |
||||||
|
c = 0.02*np.max(scipy.spatial.distance.cdist(X, X, metric='euclidean')) |
||||||
|
|
||||||
|
iterator = trange(0, n, position=0, leave=True) |
||||||
|
iterator.set_description("Topological noise reduction") |
||||||
|
for i in iterator: |
||||||
|
gradF = compute_gradient(S.to_numpy(), X.to_numpy(), sigma, omega) |
||||||
|
|
||||||
|
if i == 0: |
||||||
|
maxgradF = np.max(np.sqrt(np.square(gradF).sum(axis=1))) |
||||||
|
S = S + c* gradF/maxgradF |
||||||
|
|
||||||
|
if plot: |
||||||
|
fig = pyplot.figure() |
||||||
|
ax = Axes3D(fig) |
||||||
|
ax.scatter(X[0],X[1],X[2],alpha=0.1) |
||||||
|
ax.scatter(S[0],S[1],S[2]) |
||||||
|
pyplot.show() |
||||||
|
return S |
||||||
|
|
||||||
|
@njit(parallel=True) |
||||||
|
def density_estimation(X,k): |
||||||
|
"""Estimates density at each point""" |
||||||
|
N = X.shape[0] |
||||||
|
densities = np.zeros(N) |
||||||
|
for i in prange(N): |
||||||
|
distances = np.sum((X[i] - X)**2, axis=1) |
||||||
|
densities[i] = 1/np.sort(distances)[k] |
||||||
|
return densities |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def density_filtration(X, k, fraction): |
||||||
|
""" |
||||||
|
Returns the points which are in locations with high density |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
k: int |
||||||
|
Density is estimated as 1 over the distance to the k-th nearest point |
||||||
|
fraction: float between 0 and 1 |
||||||
|
The fraction of highedst density datapoints that are returned |
||||||
|
""" |
||||||
|
print("Applying density filtration...", end=" ") |
||||||
|
N = X.shape[0] |
||||||
|
X["densities"] = density_estimation(X.to_numpy().astype(np.float),k) |
||||||
|
X = X.nlargest(int(fraction * N), "densities") |
||||||
|
X = X.drop(columns="densities") |
||||||
|
print("done") |
||||||
|
return X |
||||||
|
|
||||||
|
@njit(parallel=True) |
||||||
|
def compute_averages(X, r): |
||||||
|
"""Used in neighborhood_average""" |
||||||
|
N = X.shape[0] |
||||||
|
averages = np.zeros(X.shape) |
||||||
|
for i in prange(N): |
||||||
|
distances = np.sum((X[i] - X)**2, axis=1) |
||||||
|
neighbors = X[distances < r] |
||||||
|
averages[i] = np.sum(neighbors, axis=0)/len(neighbors) |
||||||
|
return averages |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def neighborhood_average(X, r): |
||||||
|
""" |
||||||
|
Replace each point by an average over its neighborhood |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
r : float |
||||||
|
Points are averaged over all points within radius r |
||||||
|
""" |
||||||
|
print("Applying neighborhood average...", end=" ") |
||||||
|
averages = compute_averages(X.to_numpy().astype(np.float),r) |
||||||
|
print("done") |
||||||
|
result = pd.DataFrame(data=averages,index=X.index) |
||||||
|
return result |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def z_cutoff(X, z_cutoff): |
||||||
|
""" |
||||||
|
Remove outliers with a high Z-score |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
z_cutoff : float |
||||||
|
The Z-score at which points are removed |
||||||
|
""" |
||||||
|
z=np.abs(scipy.stats.zscore(np.sqrt(np.square(X).sum(axis=1)))) |
||||||
|
result = X[(z < z_cutoff)] |
||||||
|
print(f"{len(X) - len(result)} datapoints with Z-score above {z_cutoff}" |
||||||
|
+ "removed") |
||||||
|
return result |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def PCA_reduction(X, dim): |
||||||
|
""" |
||||||
|
Use principle component analysis to reduce the data to a lower dimension |
||||||
|
|
||||||
|
Also print the variance explained by each component |
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
dim : int |
||||||
|
The number of dimensions the data is reduced to |
||||||
|
""" |
||||||
|
pca = PCA(n_components=dim) |
||||||
|
pca.fit(X) |
||||||
|
columns = [i for i in range(dim)] |
||||||
|
X = pd.DataFrame(pca.transform(X), columns=columns, index=X.index) |
||||||
|
print("PCA explained variance:") |
||||||
|
print(pca.explained_variance_ratio_) |
||||||
|
return X |
@ -0,0 +1,169 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
""" |
||||||
|
Tools to compute persistence diagrams |
||||||
|
|
||||||
|
Persistent homology from ripser and gudhi library |
||||||
|
Confidence sets from arxiv:1303.7117 |
||||||
|
""" |
||||||
|
import numpy as np |
||||||
|
from scipy.spatial.distance import directed_hausdorff |
||||||
|
|
||||||
|
import matplotlib.pyplot as plt |
||||||
|
|
||||||
|
from tqdm import trange |
||||||
|
|
||||||
|
import ripser |
||||||
|
from persim import plot_diagrams |
||||||
|
import gudhi |
||||||
|
|
||||||
|
from decorators import multi_input |
||||||
|
|
||||||
|
|
||||||
|
def hausdorff(data1, data2, homdim, coeff): |
||||||
|
"""Hausdorff metric between two persistence diagrams""" |
||||||
|
dgm1 = (ripser.ripser(data1,maxdim=homdim,coeff=coeff))['dgms'] |
||||||
|
dgm2 = (ripser.ripser(data2,maxdim=homdim,coeff=coeff))['dgms'] |
||||||
|
distance = directed_hausdorff(dgm1[homdim], dgm2[homdim])[0] |
||||||
|
return distance |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def confidence(X, alpha=0.05, Nsubsamples=20, homdim=1, coeff=2): |
||||||
|
""" |
||||||
|
Compute the confidence interval of the persistence diagram of a dataset |
||||||
|
|
||||||
|
Computation done by subsampling as in arxiv:1303.7117 |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
alpha : float between 0 and 1, optional, default 0.05 |
||||||
|
1-alpha is the confidence |
||||||
|
Nsubsamples : int, optional, default 20 |
||||||
|
The number of subsamples |
||||||
|
homdim : int, optional, default 1 |
||||||
|
The dimension of the homology |
||||||
|
coeff : int prime, optional, default 2 |
||||||
|
The coefficient basis |
||||||
|
""" |
||||||
|
N = X.shape[0] |
||||||
|
distances = np.zeros(Nsubsamples) |
||||||
|
iterator = trange(0, Nsubsamples, position=0, leave=True) |
||||||
|
iterator.set_description("Computing confidence interval") |
||||||
|
for i in iterator: |
||||||
|
subsample = X.iloc[np.random.choice(N, N, replace=True)] |
||||||
|
distances[i] = hausdorff(X, subsample, homdim, coeff) |
||||||
|
distances.sort() |
||||||
|
confidence = np.sqrt(2) * 2 * distances[int(alpha*Nsubsamples)] |
||||||
|
return confidence |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def persistence(X, homdim=1, coeff=2, threshold=float('inf'), |
||||||
|
show_largest_homology=0, distance_matrix=False, Nsubsamples=0, |
||||||
|
alpha=0.05, cycle=None, save_path=None): |
||||||
|
""" |
||||||
|
Plot the persistence diagram of a dataset using ripser |
||||||
|
|
||||||
|
Also prints the five largest homology components |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
homdim : int, optional, default 1 |
||||||
|
The dimension of the homology |
||||||
|
coeff : int prime, optional, default 2 |
||||||
|
The coefficient basis |
||||||
|
threshold : float, optional, default infinity |
||||||
|
The maximum distance in the filtration |
||||||
|
show_largest_homology: int, optional, default 0 |
||||||
|
Print this many of the largest homology components |
||||||
|
distance_matrix : bool, optional, default False |
||||||
|
When true X will be interepreted as a distance matrix |
||||||
|
Nsubsamples : int, optional, default 0 |
||||||
|
The number of subsamples used in computing the confidence interval |
||||||
|
Does not compute the confidence interval when this is 0 |
||||||
|
alpha : float between 0 and 1, optional, default 0.05 |
||||||
|
1-alpha is the confidence |
||||||
|
cycle : int, optional, default None |
||||||
|
If given highlight the homology component in the plot corresponding to |
||||||
|
this cycle id |
||||||
|
save_path : str, optional, default None |
||||||
|
When given save the plot here |
||||||
|
""" |
||||||
|
result = ripser.ripser(X, maxdim=homdim, coeff=coeff, do_cocycles=True, |
||||||
|
distance_matrix=distance_matrix, thresh=threshold) |
||||||
|
diagrams = result['dgms'] |
||||||
|
plot_diagrams(diagrams, show=False) |
||||||
|
if (Nsubsamples>0): |
||||||
|
conf = confidence(X, alpha, Nsubsamples, homdim, 2) |
||||||
|
line_length = 10000 |
||||||
|
plt.plot([0, line_length], [conf, line_length + conf], color='green', |
||||||
|
linestyle='dashed',linewidth=2) |
||||||
|
if cycle is not None: |
||||||
|
dgm1 = diagrams[1] |
||||||
|
plt.scatter(dgm1[cycle, 0], dgm1[cycle, 1], 20, 'k', 'x') |
||||||
|
if save_path is not None: |
||||||
|
path = save_path + 'Z' + str(coeff) |
||||||
|
if (Nsubsamples>0): |
||||||
|
path += '_confidence' + str(1-alpha) |
||||||
|
path += '.png' |
||||||
|
plt.savefig(path) |
||||||
|
plt.show() |
||||||
|
|
||||||
|
if show_largest_homology != 0: |
||||||
|
dgm = diagrams[homdim] |
||||||
|
largest_indices = np.argsort(dgm[:, 0] - dgm[:, 1]) |
||||||
|
largest_components = dgm[largest_indices[:show_largest_homology]] |
||||||
|
print(f"Largest {homdim}-homology components:") |
||||||
|
print(largest_components) |
||||||
|
return |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def persistence_witness(X, number_of_landmarks=100, max_alpha_square=0.0, |
||||||
|
homdim=1): |
||||||
|
""" |
||||||
|
Plot the persistence diagram of a dataset using gudhi |
||||||
|
|
||||||
|
Uses a witness complex allowing it to be used on larger datasets |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
X: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
number_of_landmarks : int, optional, default 100 |
||||||
|
The number of landmarks in the witness complex |
||||||
|
max_alpha_square : double, optional, default 0.0 |
||||||
|
Maximal squared relaxation parameter |
||||||
|
homdim : int, optional, default 1 |
||||||
|
The dimension of the homology |
||||||
|
""" |
||||||
|
print("Sampling landmarks...", end=" ") |
||||||
|
|
||||||
|
witnesses = X.to_numpy() |
||||||
|
landmarks = gudhi.pick_n_random_points( |
||||||
|
points=witnesses, nb_points=number_of_landmarks |
||||||
|
) |
||||||
|
print("done") |
||||||
|
message = ( |
||||||
|
"EuclideanStrongWitnessComplex with max_edge_length=" |
||||||
|
+ repr(max_alpha_square) |
||||||
|
+ " - Number of landmarks=" |
||||||
|
+ repr(number_of_landmarks) |
||||||
|
) |
||||||
|
print(message) |
||||||
|
witness_complex = gudhi.EuclideanStrongWitnessComplex( |
||||||
|
witnesses=witnesses, landmarks=landmarks |
||||||
|
) |
||||||
|
simplex_tree = witness_complex.create_simplex_tree( |
||||||
|
max_alpha_square=max_alpha_square, |
||||||
|
limit_dimension=homdim |
||||||
|
) |
||||||
|
message = "Number of simplices=" + repr(simplex_tree.num_simplices()) |
||||||
|
print(message) |
||||||
|
diag = simplex_tree.persistence() |
||||||
|
print("betti_numbers()=") |
||||||
|
print(simplex_tree.betti_numbers()) |
||||||
|
gudhi.plot_persistence_diagram(diag, band=0.0) |
||||||
|
plt.show() |
||||||
|
return |
@ -0,0 +1,343 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
""" |
||||||
|
Some plotting functions related to grating stimuli responses |
||||||
|
""" |
||||||
|
import numpy as np |
||||||
|
import pandas as pd |
||||||
|
import math |
||||||
|
import random |
||||||
|
|
||||||
|
import plotly.graph_objects as go |
||||||
|
from plotly.offline import plot |
||||||
|
from plotly.subplots import make_subplots |
||||||
|
import matplotlib.pyplot as plt |
||||||
|
import imageio |
||||||
|
|
||||||
|
import sklearn.manifold as manifold |
||||||
|
from sklearn.decomposition import PCA |
||||||
|
|
||||||
|
from tqdm import trange |
||||||
|
|
||||||
|
from decorators import multi_input |
||||||
|
from gratings import angular_mean, grating_image |
||||||
|
|
||||||
|
|
||||||
|
@multi_input |
||||||
|
def plot_data(data, transformation="None", labels=None, colors=None, |
||||||
|
save_path=None): |
||||||
|
""" |
||||||
|
Plot data colored by its indices and additional provided labels |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
data: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
transformation: str, optional, default "None" |
||||||
|
The type of dimension reduction used |
||||||
|
Choose from "None", PCA" or "SpectralEmbedding" |
||||||
|
labels: dataframe(n_datapoints, n_labels), optional, default None |
||||||
|
Dataframe containing additional labels to be plotted |
||||||
|
colors: list of str, optional, default None |
||||||
|
A list containing the color scales used for each label |
||||||
|
When None use "Viridis" for all labels |
||||||
|
save_path: str, optional, default None |
||||||
|
When given save the figure here |
||||||
|
|
||||||
|
Raises |
||||||
|
------ |
||||||
|
ValueError |
||||||
|
When an invalid value for "transformation" is given |
||||||
|
""" |
||||||
|
# Set labels |
||||||
|
indices = data.index |
||||||
|
plotted_labels = indices.to_frame() |
||||||
|
if labels is not None: |
||||||
|
plotted_labels = plotted_labels.join(labels) |
||||||
|
if colors is None: |
||||||
|
colors = ["Viridis"]*len(plotted_labels.columns) |
||||||
|
Nlabels = len(plotted_labels.columns) |
||||||
|
|
||||||
|
# Transform data to lower dimension |
||||||
|
if (transformation == "PCA"): |
||||||
|
pca = PCA(n_components=3) |
||||||
|
pca.fit(data) |
||||||
|
data = pca.transform(data) |
||||||
|
elif (transformation == "SpectralEmbedding"): |
||||||
|
embedding = manifold.SpectralEmbedding(n_components=3, affinity='rbf') |
||||||
|
data = embedding.fit_transform(data) |
||||||
|
elif (transformation == "None"): |
||||||
|
pass |
||||||
|
else: |
||||||
|
raise ValueError("Invalid plot transformation") |
||||||
|
|
||||||
|
# Plot |
||||||
|
data = pd.DataFrame(data) |
||||||
|
data.index = indices |
||||||
|
fig = go.Figure() |
||||||
|
fig = make_subplots(rows=2, |
||||||
|
cols=math.ceil(Nlabels/2), |
||||||
|
specs=[[{'type': 'scene'}]*math.ceil(Nlabels/2)]*2) |
||||||
|
for i,label in enumerate(plotted_labels): |
||||||
|
fig.add_trace( |
||||||
|
go.Scatter3d( |
||||||
|
mode='markers', |
||||||
|
name=label, |
||||||
|
x=data[0], |
||||||
|
y=data[1], |
||||||
|
z=data[2], |
||||||
|
text = plotted_labels[label], |
||||||
|
hoverinfo = ['x','text'], |
||||||
|
marker=dict( |
||||||
|
color=plotted_labels[label], |
||||||
|
size=5, |
||||||
|
sizemode='diameter', |
||||||
|
colorscale=colors[i] |
||||||
|
) |
||||||
|
), |
||||||
|
row=i%2 + 1, col=math.floor(i/2)+1 |
||||||
|
) |
||||||
|
fig.update_layout(height=900, width=1600, title_text="") |
||||||
|
if save_path is None: |
||||||
|
plot(fig) |
||||||
|
fig.show() |
||||||
|
else: |
||||||
|
path = save_path |
||||||
|
path += 'plot' |
||||||
|
if transformation != 'None': |
||||||
|
path += '_' + transformation |
||||||
|
path += ".html" |
||||||
|
fig.write_html(path) |
||||||
|
return |
||||||
|
|
||||||
|
def plot_connections(data_points, connections, threshold=0.1, opacity=0.1, |
||||||
|
save_path=None): |
||||||
|
#draw a square |
||||||
|
x = data_points[:,0] |
||||||
|
y = data_points[:,1] |
||||||
|
z = data_points[:,2] |
||||||
|
N = len(x) |
||||||
|
|
||||||
|
#the start and end point for each line |
||||||
|
pairs = [(i,j) for i,j in np.ndindex((N,N)) if connections[i,j] > threshold] |
||||||
|
|
||||||
|
trace1 = go.Scatter3d( |
||||||
|
x=x, |
||||||
|
y=y, |
||||||
|
z=z, |
||||||
|
mode='markers', |
||||||
|
name='markers' |
||||||
|
) |
||||||
|
|
||||||
|
x_lines = list() |
||||||
|
y_lines = list() |
||||||
|
z_lines = list() |
||||||
|
|
||||||
|
#create the coordinate list for the lines |
||||||
|
for p in pairs: |
||||||
|
for i in range(2): |
||||||
|
x_lines.append(x[p[i]]) |
||||||
|
y_lines.append(y[p[i]]) |
||||||
|
z_lines.append(z[p[i]]) |
||||||
|
x_lines.append(None) |
||||||
|
y_lines.append(None) |
||||||
|
z_lines.append(None) |
||||||
|
|
||||||
|
trace2 = go.Scatter3d( |
||||||
|
x=x_lines, |
||||||
|
y=y_lines, |
||||||
|
z=z_lines, |
||||||
|
opacity=opacity, |
||||||
|
mode='lines', |
||||||
|
name='lines' |
||||||
|
) |
||||||
|
|
||||||
|
fig = go.Figure(data=[trace1,trace2]) |
||||||
|
|
||||||
|
if save_path is not None: |
||||||
|
path = save_path |
||||||
|
path += '/plot_glue.html' |
||||||
|
fig.write_html(path) |
||||||
|
else: |
||||||
|
plot(fig) |
||||||
|
return |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def plot_mean_against_index(data, value, index, circ=True, save_path=None): |
||||||
|
""" |
||||||
|
Plot the mean value against an index |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
data: dataframe(n_datapoints, n_features): |
||||||
|
Dataframe containing the data |
||||||
|
value : dataframe(n_datapoints) |
||||||
|
The value we average over |
||||||
|
index : str |
||||||
|
The name of the index we plot against |
||||||
|
circ : bool, optional, default True |
||||||
|
Whether or not the index is angular |
||||||
|
""" |
||||||
|
unique_index = data.reset_index()[index].unique() |
||||||
|
if len(unique_index) > 1: |
||||||
|
if circ: |
||||||
|
means = value.groupby([index]).agg(lambda X: angular_mean(X, period=1)) |
||||||
|
else: |
||||||
|
means = value.groupby([index]).mean() |
||||||
|
plt.scatter(means.index, means) |
||||||
|
if circ: |
||||||
|
plt.ylim(0, 1) |
||||||
|
plt.xlabel(index) |
||||||
|
plt.ylabel('mean') |
||||||
|
plt.show() |
||||||
|
return |
||||||
|
|
||||||
|
@multi_input |
||||||
|
def show_feature(decoding, Nimages=10, Npixels=100, normalized=True, |
||||||
|
intervals="equal_images", save_path=None): |
||||||
|
""" |
||||||
|
Show how the gratings depend on a decoded parameter |
||||||
|
|
||||||
|
Shows the average grating for different values of the decoding and plot the |
||||||
|
grating parameters as a function of the decoding |
||||||
|
|
||||||
|
Parameters |
||||||
|
---------- |
||||||
|
decoding : dataframe(n_datapoints) |
||||||
|
A dataframe containing the decoded value for each data point |
||||||
|
labeled by indices "orientation, "frequency", "phase" and "contrast" |
||||||
|
Nimages : int, optional, default 10 |
||||||
|
Number of different images shown |
||||||
|
Npixels: int, optional, default 100 |
||||||
|
The number of pixels in each direction |
||||||
|
normalized : bool, optional, default True |
||||||
|
If true normalize the average images |
||||||
|
intervals : str, optional, default "equal_images" |
||||||
|
How the images are binned together |
||||||
|
- When "equal_images" average such that each image is an average of |
||||||
|
an equal number of images |
||||||
|
- When "equal_decoding" average such that each image is averaged from |
||||||
|
images within an equal fraction of the decoding |
||||||
|
save_path: str, optional, default None |
||||||
|
If given, save a gif here |
||||||
|
|
||||||
|
Raises |
||||||
|
------ |
||||||
|
ValueError |
||||||
|
When an invalid value for "intervals" is given |
||||||
|
|
||||||
|
Warns |
||||||
|
----- |
||||||
|
When some of the bins are empty |
||||||
|
""" |
||||||
|
try: |
||||||
|
orientation = decoding.reset_index()["orientation"] |
||||||
|
except KeyError: |
||||||
|
orientation = pd.Series(np.ones(len(decoding))) |
||||||
|
try: |
||||||
|
contrast = decoding.reset_index()["contrast"] |
||||||
|
except KeyError: |
||||||
|
contrast = pd.Series(np.ones(len(decoding))) |
||||||
|
try: |
||||||
|
frequency = decoding.reset_index()["frequency"] |
||||||
|
except KeyError: |
||||||
|
frequency = pd.Series(np.ones(len(decoding))) |
||||||
|
try: |
||||||
|
phase = decoding.reset_index()["phase"] |
||||||
|
except KeyError: |
||||||
|
phase = pd.Series(np.ones(len(decoding))) |
||||||
|
|
||||||
|
decoding = decoding.to_numpy().ravel() |
||||||
|
N = decoding.shape[0] |
||||||
|
|
||||||
|
# Convert grating labels |
||||||
|
if (max(phase) > 2*np.pi): |
||||||
|
orientation = orientation * np.pi/180 |
||||||
|
frequency = frequency * 30 |
||||||
|
phase = phase * np.pi/180 |
||||||
|
|
||||||
|
# Assign images to intervals |
||||||
|
interval_labels = np.zeros(N, dtype=int) |
||||||
|
count = np.zeros(Nimages) |
||||||
|
if intervals=="equal_decoding": |
||||||
|
intervals = np.linspace(min(decoding), max(decoding), Nimages+1) |
||||||
|
for i in range(Nimages): |
||||||
|
for j in range(N): |
||||||
|
if (intervals[i] <= decoding[j]) and (decoding[j] <= intervals[i+1]): |
||||||
|
interval_labels[j] = i |
||||||
|
count[i] += 1 |
||||||
|
elif intervals=="equal_images": |
||||||
|
interval_labels = (np.floor(np.linspace(0,Nimages-0.01,N))).astype(int) |
||||||
|
interval_labels = interval_labels[np.argsort(decoding)] |
||||||
|
for i in range(Nimages): |
||||||
|
count[i] = (interval_labels[interval_labels==i]).shape[0] |
||||||
|
else: |
||||||
|
raise ValueError("Invalid intervals type") |
||||||
|
return |
||||||
|
|
||||||
|
# Average over grating images |
||||||
|
av_images = np.zeros([Nimages, Npixels, Npixels]) |
||||||
|
grouped_indices = [[] for _ in range(Nimages)] |
||||||
|
iterator = trange(0, N, position=0, leave=True) |
||||||
|
iterator.set_description("Averaging images") |
||||||
|
for j in iterator: |
||||||
|
pars = np.array([orientation[j], frequency[j], phase[j], contrast[j]]) |
||||||
|
grouped_indices[interval_labels[j]].append(pars) |
||||||
|
if random.choice([True, False]): |
||||||
|
av_images[interval_labels[j]] += ( |
||||||
|
grating_image(pars, N=Npixels, plot=False) |
||||||
|
) |
||||||
|
for i in range(Nimages): |
||||||
|
if count[i] != 0: |
||||||
|
av_images[i] = av_images[i]/count[i] |
||||||
|
|
||||||
|
print("Number of images averaged over:") |
||||||
|
print(count) |
||||||
|
|
||||||
|
if normalized: |
||||||
|
av_images = av_images/np.max(np.abs(av_images)) |
||||||
|
|
||||||
|
# Show averages and make gif |
||||||
|
for i in range(Nimages): |
||||||
|
if (not math.isnan(av_images[i,0,0])): |
||||||
|
plt.imshow(av_images[i], "gray", vmin = -1, vmax = 1) |
||||||
|
if save_path is not None: |
||||||
|
plt.savefig(save_path + "_image" + str(i+1) + ".png") |
||||||
|
plt.show() |
||||||
|
if save_path is not None: |
||||||
|
frames = (255*0.5*(1+av_images)).astype(np.uint8) |
||||||
|
imageio.mimsave(save_path + ".gif", frames) |
||||||
|
|
||||||
|
# Plot the grating parameters against the decoding |
||||||
|
try: |
||||||
|
ori_fun = lambda x : angular_mean(np.array(x)[:,0],period=np.pi,axis=0) |
||||||
|
freq_fun = lambda x: np.mean(np.array(x)[:,1],axis=0) |
||||||
|
phase_fun = lambda x: angular_mean(np.array(x)[:,2],period=2*np.pi,axis=0) |
||||||
|
contrast_fun = lambda x: np.mean(np.array(x)[:,3],axis=0) |
||||||
|
av_orientation = np.array(list(map(ori_fun, grouped_indices))) |
||||||
|
av_frequency = np.array(list(map(freq_fun, grouped_indices))) |
||||||
|
av_phase = np.array(list(map(phase_fun, grouped_indices))) |
||||||
|
av_contrast = np.array(list(map(contrast_fun, grouped_indices))) |
||||||
|
except IndexError: |
||||||
|
print("Error: some bins are empty; choose a smaller bin count.") |
||||||
|
else: |
||||||
|
if len(set(orientation)) > 1: |
||||||
|
plt.scatter(np.arange(Nimages), av_orientation) |
||||||
|
plt.xlabel("Decoding") |
||||||
|
plt.ylabel("Orientation") |
||||||
|
plt.show() |
||||||
|
if len(set(frequency)) > 1: |
||||||
|
plt.scatter(np.arange(Nimages), av_frequency) |
||||||
|
plt.xlabel("Decoding") |
||||||
|
plt.ylabel("Frequency") |
||||||
|
plt.show() |
||||||
|
if len(set(phase)) > 1: |
||||||
|
plt.scatter(np.arange(Nimages), av_phase) |
||||||
|
plt.xlabel("Decoding") |
||||||
|
plt.ylabel("Phase") |
||||||
|
plt.show() |
||||||
|
if len(set(contrast)) > 1: |
||||||
|
plt.scatter(np.arange(Nimages), av_contrast) |
||||||
|
plt.xlabel("Decoding") |
||||||
|
plt.ylabel("Contrast") |
||||||
|
plt.show() |
||||||
|
return |
Loading…
Reference in new issue