You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
344 lines
11 KiB
344 lines
11 KiB
2 years ago
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
Some plotting functions related to grating stimuli responses
|
||
|
"""
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
import math
|
||
|
import random
|
||
|
|
||
|
import plotly.graph_objects as go
|
||
|
from plotly.offline import plot
|
||
|
from plotly.subplots import make_subplots
|
||
|
import matplotlib.pyplot as plt
|
||
|
import imageio
|
||
|
|
||
|
import sklearn.manifold as manifold
|
||
|
from sklearn.decomposition import PCA
|
||
|
|
||
|
from tqdm import trange
|
||
|
|
||
|
from decorators import multi_input
|
||
|
from gratings import angular_mean, grating_image
|
||
|
|
||
|
|
||
|
@multi_input
def plot_data(data, transformation="None", labels=None, colors=None,
              save_path=None):
    """
    Plot data colored by its indices and additional provided labels

    One 3D scatter subplot is drawn per label column (the index levels of
    ``data`` plus any columns of ``labels``). The subplots are laid out in
    two rows, filled column by column.

    Parameters
    ----------
    data: dataframe(n_datapoints, n_features):
        Dataframe containing the data
        After the optional transformation, the columns labelled 0, 1 and 2
        are used as the x, y and z coordinates
    transformation: str, optional, default "None"
        The type of dimension reduction used
        Choose from "None", "PCA" or "SpectralEmbedding"
    labels: dataframe(n_datapoints, n_labels), optional, default None
        Dataframe containing additional labels to be plotted
    colors: list of str, optional, default None
        A list containing the color scales used for each label
        When None use "Viridis" for all labels
    save_path: str, optional, default None
        When given save the figure here
        Used as a filename prefix: the file written is
        ``save_path + "plot" [+ "_" + transformation] + ".html"``
        When None the figure is displayed instead (both through
        ``plotly.offline.plot`` and ``fig.show``)

    Raises
    ------
    ValueError
        When an invalid value for "transformation" is given
    """
    # Set labels
    indices = data.index
    plotted_labels = indices.to_frame()
    if labels is not None:
        plotted_labels = plotted_labels.join(labels)
    n_labels = len(plotted_labels.columns)
    if colors is None:
        colors = ["Viridis"] * n_labels

    # Transform data to lower dimension
    if transformation == "PCA":
        pca = PCA(n_components=3)
        pca.fit(data)
        data = pca.transform(data)
    elif transformation == "SpectralEmbedding":
        embedding = manifold.SpectralEmbedding(n_components=3, affinity='rbf')
        data = embedding.fit_transform(data)
    elif transformation != "None":
        raise ValueError("Invalid plot transformation")

    # Plot
    data = pd.DataFrame(data)
    data.index = indices
    n_cols = math.ceil(n_labels / 2)
    fig = make_subplots(rows=2,
                        cols=n_cols,
                        specs=[[{'type': 'scene'}] * n_cols] * 2)
    for i, label in enumerate(plotted_labels):
        fig.add_trace(
            go.Scatter3d(
                mode='markers',
                name=label,
                x=data[0],
                y=data[1],
                z=data[2],
                text=plotted_labels[label],
                # NOTE(review): hoverinfo is passed as a list here; plotly
                # documents it as a '+'-joined flag string (e.g. 'x+text').
                # Left unchanged -- confirm which was intended.
                hoverinfo=['x', 'text'],
                marker=dict(
                    color=plotted_labels[label],
                    size=5,
                    sizemode='diameter',
                    colorscale=colors[i]
                )
            ),
            # Fill the 2-row grid column by column
            row=i % 2 + 1, col=i // 2 + 1
        )
    fig.update_layout(height=900, width=1600, title_text="")

    if save_path is None:
        plot(fig)
        fig.show()
    else:
        path = save_path + 'plot'
        if transformation != 'None':
            path += '_' + transformation
        path += ".html"
        fig.write_html(path)
|
||
|
|
||
|
def plot_connections(data_points, connections, threshold=0.1, opacity=0.1,
                     save_path=None):
    """
    Plot 3D points together with lines between connected points

    A line is drawn from point i to point j for every ordered pair (i, j)
    with ``connections[i, j] > threshold``.

    Parameters
    ----------
    data_points: array
        Indexed as ``data_points[:, k]`` for k = 0, 1, 2 to obtain the
        x, y and z coordinates of the points
    connections: array
        Indexed as ``connections[i, j]`` for all i, j below the number of
        points
    threshold: float, optional, default 0.1
        Only entries strictly larger than this value produce a line
    opacity: float, optional, default 0.1
        Opacity passed to the line trace
    save_path: str, optional, default None
        When given write the figure to ``save_path + '/plot_glue.html'``
        When None display the figure with ``plotly.offline.plot``
    """
    xs = data_points[:, 0]
    ys = data_points[:, 1]
    zs = data_points[:, 2]
    n_points = len(xs)

    # Index pairs (start, end) of every connection above the threshold
    linked = [(i, j) for i, j in np.ndindex((n_points, n_points))
              if connections[i, j] > threshold]

    point_trace = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        name='markers'
    )

    # Coordinates of all segments in one flat list per axis; a None entry
    # after each start/end pair separates consecutive segments
    x_lines = []
    y_lines = []
    z_lines = []
    for start, end in linked:
        x_lines.extend((xs[start], xs[end], None))
        y_lines.extend((ys[start], ys[end], None))
        z_lines.extend((zs[start], zs[end], None))

    line_trace = go.Scatter3d(
        x=x_lines,
        y=y_lines,
        z=z_lines,
        opacity=opacity,
        mode='lines',
        name='lines'
    )

    fig = go.Figure(data=[point_trace, line_trace])

    if save_path is None:
        plot(fig)
    else:
        fig.write_html(save_path + '/plot_glue.html')
|
||
|
|
||
|
@multi_input
def plot_mean_against_index(data, value, index, circ=True, save_path=None):
    """
    Scatter the per-group mean of a value against one index level

    Nothing is plotted when the index level takes at most one distinct
    value.

    Parameters
    ----------
    data: dataframe(n_datapoints, n_features):
        Dataframe containing the data; only used to look up the distinct
        values of ``index``
    value : dataframe(n_datapoints)
        The value we average over, grouped by ``index``
    index : str
        The name of the index we plot against
    circ : bool, optional, default True
        Whether or not the value is angular
        When True the group mean is ``angular_mean(..., period=1)`` and the
        y-axis is limited to [0, 1]; otherwise the arithmetic mean is used
    save_path: str, optional, default None
        Not read anywhere in this function body
    """
    distinct_values = data.reset_index()[index].unique()
    if len(distinct_values) <= 1:
        return

    grouped = value.groupby([index])
    if circ:
        means = grouped.agg(lambda group: angular_mean(group, period=1))
    else:
        means = grouped.mean()

    plt.scatter(means.index, means)
    if circ:
        plt.ylim(0, 1)
    plt.xlabel(index)
    plt.ylabel('mean')
    plt.show()
|
||
|
|
||
|
@multi_input
def show_feature(decoding, Nimages=10, Npixels=100, normalized=True,
                 intervals="equal_images", save_path=None):
    """
    Show how the gratings depend on a decoded parameter

    Shows the average grating for different values of the decoding and plot the
    grating parameters as a function of the decoding

    Parameters
    ----------
    decoding : dataframe(n_datapoints)
        A dataframe containing the decoded value for each data point
        labeled by indices "orientation", "frequency", "phase" and "contrast"
        A missing index level is replaced by a constant value of one
    Nimages : int, optional, default 10
        Number of different images shown
    Npixels: int, optional, default 100
        The number of pixels in each direction
    normalized : bool, optional, default True
        If true normalize the average images
    intervals : str, optional, default "equal_images"
        How the images are binned together
        - When "equal_images" average such that each image is an average of
        an equal number of images
        - When "equal_decoding" average such that each image is averaged from
        images within an equal fraction of the decoding
    save_path: str, optional, default None
        If given, save a gif here (``save_path + ".gif"``) and each shown
        image as ``save_path + "_image<i>.png"``

    Raises
    ------
    ValueError
        When an invalid value for "intervals" is given

    Notes
    -----
    - When some of the bins are empty a message is printed and the
      parameter plots are skipped.
    - Only a random subset of the gratings (each one is kept with
      probability 1/2) is rendered for the average images; every bin is
      divided by the number of gratings actually rendered for it.
    - If the largest phase exceeds 2*pi, orientation and phase are
      multiplied by pi/180 (presumably a degrees-to-radians conversion)
      and frequency by 30.
    """
    # Read the grating parameters from the index levels
    index_frame = decoding.reset_index()
    n_rows = len(decoding)

    def index_level(name):
        try:
            return index_frame[name]
        except KeyError:
            return pd.Series(np.ones(n_rows))

    orientation = index_level("orientation")
    contrast = index_level("contrast")
    frequency = index_level("frequency")
    phase = index_level("phase")

    decoding = decoding.to_numpy().ravel()
    N = decoding.shape[0]

    # Convert grating labels
    if max(phase) > 2*np.pi:
        orientation = orientation * np.pi/180
        frequency = frequency * 30
        phase = phase * np.pi/180

    # Assign images to intervals
    if intervals == "equal_decoding":
        edges = np.linspace(min(decoding), max(decoding), Nimages + 1)
        # Bin k holds edges[k] <= value < edges[k+1]; the maximum goes into
        # the last bin. Every point is assigned to (and counted in) exactly
        # one bin, also when it lies on an interior edge.
        interval_labels = np.clip(
            np.searchsorted(edges, decoding, side="right") - 1,
            0, Nimages - 1)
    elif intervals == "equal_images":
        # Bin label as a function of the rank of a point in the decoding
        rank_labels = np.floor(np.linspace(0, Nimages - 0.01, N)).astype(int)
        # np.argsort(decoding)[r] is the index of the point with rank r, so
        # the labels have to be scattered to those positions (indexing
        # rank_labels with the argsort would apply the inverse permutation)
        interval_labels = np.empty(N, dtype=int)
        interval_labels[np.argsort(decoding)] = rank_labels
    else:
        raise ValueError("Invalid intervals type")
    # Number of data points in each bin
    count = np.bincount(interval_labels, minlength=Nimages).astype(float)

    # Average over grating images
    av_images = np.zeros([Nimages, Npixels, Npixels])
    n_rendered = np.zeros(Nimages)
    grouped_indices = [[] for _ in range(Nimages)]
    iterator = trange(0, N, position=0, leave=True)
    iterator.set_description("Averaging images")
    for j in iterator:
        pars = np.array([orientation[j], frequency[j], phase[j], contrast[j]])
        bin_index = interval_labels[j]
        grouped_indices[bin_index].append(pars)
        # Only render a random half of the gratings
        if random.choice([True, False]):
            av_images[bin_index] += grating_image(pars, N=Npixels, plot=False)
            n_rendered[bin_index] += 1
    # Divide by the number of images that were actually summed, so that the
    # result is a mean independent of how many gratings were skipped
    has_images = n_rendered > 0
    av_images[has_images] /= n_rendered[has_images, None, None]

    print("Number of images averaged over:")
    print(count)

    if normalized:
        av_images = av_images/np.max(np.abs(av_images))

    # Show averages and make gif
    for i in range(Nimages):
        if not math.isnan(av_images[i, 0, 0]):
            plt.imshow(av_images[i], "gray", vmin=-1, vmax=1)
            if save_path is not None:
                plt.savefig(save_path + "_image" + str(i+1) + ".png")
            plt.show()
    if save_path is not None:
        # Map [-1, 1] to [0, 255]
        frames = (255*0.5*(1+av_images)).astype(np.uint8)
        imageio.mimsave(save_path + ".gif", frames)

    # Plot the grating parameters against the decoding
    try:
        # Indexing a column of an empty bin raises IndexError
        stacked = [np.array(group) for group in grouped_indices]
        av_orientation = np.array(
            [angular_mean(pars[:, 0], period=np.pi, axis=0)
             for pars in stacked])
        av_frequency = np.array(
            [np.mean(pars[:, 1], axis=0) for pars in stacked])
        av_phase = np.array(
            [angular_mean(pars[:, 2], period=2*np.pi, axis=0)
             for pars in stacked])
        av_contrast = np.array(
            [np.mean(pars[:, 3], axis=0) for pars in stacked])
    except IndexError:
        print("Error: some bins are empty; choose a smaller bin count.")
    else:
        parameter_plots = (
            ("Orientation", orientation, av_orientation),
            ("Frequency", frequency, av_frequency),
            ("Phase", phase, av_phase),
            ("Contrast", contrast, av_contrast),
        )
        for name, values, averages in parameter_plots:
            # Only plot parameters that actually vary in the data
            if len(set(values)) > 1:
                plt.scatter(np.arange(Nimages), averages)
                plt.xlabel("Decoding")
                plt.ylabel(name)
                plt.show()
|