You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
343 lines
11 KiB
343 lines
11 KiB
# -*- coding: utf-8 -*- |
|
""" |
|
Some plotting functions related to grating stimuli responses |
|
""" |
|
import numpy as np |
|
import pandas as pd |
|
import math |
|
import random |
|
|
|
import plotly.graph_objects as go |
|
from plotly.offline import plot |
|
from plotly.subplots import make_subplots |
|
import matplotlib.pyplot as plt |
|
import imageio |
|
|
|
import sklearn.manifold as manifold |
|
from sklearn.decomposition import PCA |
|
|
|
from tqdm import trange |
|
|
|
from decorators import multi_input |
|
from gratings import angular_mean, grating_image |
|
|
|
|
|
@multi_input
def plot_data(data, transformation="None", labels=None, colors=None,
              save_path=None):
    """
    Plot data in 3D, colored by its indices and additional provided labels

    One 3D scatter subplot is drawn per label -- first every index level of
    `data`, then every column of `labels` -- laid out on a grid with two
    rows, filled column by column.

    Parameters
    ----------
    data: dataframe(n_datapoints, n_features):
        Dataframe containing the data; after the optional transformation its
        first three columns (by position) are used as x, y and z
    transformation: str, optional, default "None"
        The type of dimension reduction used
        Choose from "None", "PCA" or "SpectralEmbedding"
    labels: dataframe(n_datapoints, n_labels), optional, default None
        Dataframe containing additional labels to be plotted; it is joined
        onto the index levels of `data`
    colors: list of str, optional, default None
        A list containing the color scale used for each label (index levels
        first, then the columns of `labels`)
        When None use "Viridis" for all labels
    save_path: str, optional, default None
        When given, save the figure as html to
        save_path + "plot" [+ "_" + transformation] + ".html"
        (the transformation suffix is omitted for "None"; no path separator
        is inserted after save_path)
        When None, display the figure instead

    Raises
    ------
    ValueError
        When an invalid value for "transformation" is given
    """
    # Set labels: every index level, plus any additional labels
    indices = data.index
    plotted_labels = indices.to_frame()
    if labels is not None:
        plotted_labels = plotted_labels.join(labels)
    Nlabels = len(plotted_labels.columns)
    if colors is None:
        colors = ["Viridis"] * Nlabels

    # Transform data to lower dimension
    if transformation == "PCA":
        pca = PCA(n_components=3)
        pca.fit(data)
        coords = pca.transform(data)
    elif transformation == "SpectralEmbedding":
        embedding = manifold.SpectralEmbedding(n_components=3, affinity='rbf')
        coords = embedding.fit_transform(data)
    elif transformation == "None":
        coords = data
    else:
        raise ValueError("Invalid plot transformation")
    # Select the coordinates by position, so untransformed data works
    # whatever its column labels are
    coords = np.asarray(coords)

    # Plot
    ncols = math.ceil(Nlabels / 2)
    fig = make_subplots(rows=2,
                        cols=ncols,
                        specs=[[{'type': 'scene'}] * ncols for _ in range(2)])
    for i, label in enumerate(plotted_labels):
        fig.add_trace(
            go.Scatter3d(
                mode='markers',
                name=label,
                x=coords[:, 0],
                y=coords[:, 1],
                z=coords[:, 2],
                text=plotted_labels[label],
                # A single flaglist string; a list would be applied per point
                hoverinfo='x+text',
                marker=dict(
                    color=plotted_labels[label],
                    size=5,
                    sizemode='diameter',
                    colorscale=colors[i],
                ),
            ),
            # Labels 0 and 1 go in column 1, labels 2 and 3 in column 2, ...
            row=i % 2 + 1, col=i // 2 + 1,
        )
    fig.update_layout(height=900, width=1600, title_text="")

    if save_path is None:
        plot(fig)
        fig.show()
    else:
        path = save_path + 'plot'
        if transformation != 'None':
            path += '_' + transformation
        path += ".html"
        fig.write_html(path)
    return
|
|
|
def plot_connections(data_points, connections, threshold=0.1, opacity=0.1,
                     save_path=None):
    """
    Plot points in 3D with a line between every strongly connected pair

    Parameters
    ----------
    data_points: array(n_points, >=3)
        Coordinates of the points; the first three columns are used as
        x, y and z
    connections: array(n_points, n_points)
        Connection strengths; a line is drawn from point i to point j
        whenever connections[i, j] > threshold (only the top-left
        n_points x n_points part is considered)
    threshold: float, optional, default 0.1
        Connection strength that must be strictly exceeded to draw a line
    opacity: float, optional, default 0.1
        Opacity of the lines
    save_path: str, optional, default None
        When given, save the figure to save_path + '/plot_glue.html';
        otherwise display it
    """
    points = np.asarray(data_points)
    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]
    N = len(x)

    # (i, j) index pairs, in row-major order, whose connection exceeds the
    # threshold -- computed in one vectorised comparison
    mask = connections[:N, :N] > threshold
    if hasattr(mask, "toarray"):  # sparse matrices (e.g. scipy.sparse)
        mask = mask.toarray()
    pairs = np.argwhere(mask)

    trace1 = go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        name='markers'
    )

    # Coordinate lists for the lines: start, end, then None so plotly
    # breaks the polyline between consecutive segments
    x_lines, y_lines, z_lines = [], [], []
    for i, j in pairs:
        x_lines.extend((x[i], x[j], None))
        y_lines.extend((y[i], y[j], None))
        z_lines.extend((z[i], z[j], None))

    trace2 = go.Scatter3d(
        x=x_lines,
        y=y_lines,
        z=z_lines,
        opacity=opacity,
        mode='lines',
        name='lines'
    )

    fig = go.Figure(data=[trace1, trace2])

    if save_path is not None:
        path = save_path
        path += '/plot_glue.html'
        fig.write_html(path)
    else:
        plot(fig)
    return
|
|
|
@multi_input
def plot_mean_against_index(data, value, index, circ=True, save_path=None):
    """
    Plot the mean value against an index

    Nothing is plotted when `index` takes fewer than two distinct values.

    Parameters
    ----------
    data: dataframe(n_datapoints, n_features):
        Dataframe containing the data; only used to count the distinct
        values of `index`
    value : dataframe(n_datapoints)
        The value we average over, grouped by the index level `index`
    index : str
        The name of the index we plot against
    circ : bool, optional, default True
        Whether or not the value is angular (period 1); if so, the angular
        mean is used and the y-axis is fixed to [0, 1]
    save_path: str, optional, default None
        When given, also save the figure to
        save_path + "_mean_" + index + ".png"
    """
    unique_index = data.reset_index()[index].unique()
    if len(unique_index) < 2:
        return

    grouped = value.groupby([index])
    if circ:
        means = grouped.agg(lambda X: angular_mean(X, period=1))
    else:
        means = grouped.mean()

    plt.scatter(means.index, means)
    if circ:
        plt.ylim(0, 1)
    plt.xlabel(index)
    plt.ylabel('mean')
    if save_path is not None:
        # Previously save_path was accepted but silently ignored
        plt.savefig(save_path + "_mean_" + str(index) + ".png")
    plt.show()
    return
|
|
|
@multi_input
def show_feature(decoding, Nimages=10, Npixels=100, normalized=True,
                 intervals="equal_images", save_path=None):
    """
    Show how the gratings depend on a decoded parameter

    Shows the average grating for different values of the decoding and plot the
    grating parameters as a function of the decoding

    Parameters
    ----------
    decoding : dataframe(n_datapoints)
        A dataframe containing the decoded value for each data point
        labeled by indices "orientation", "frequency", "phase" and "contrast"
        (a missing index is treated as constant 1)
    Nimages : int, optional, default 10
        Number of different images shown
    Npixels: int, optional, default 100
        The number of pixels in each direction
    normalized : bool, optional, default True
        If true normalize the average images by their largest absolute value
    intervals : str, optional, default "equal_images"
        How the images are binned together
        - When "equal_images" average such that each image is an average of
          an (almost) equal number of images, binned by rank of the decoding
        - When "equal_decoding" average such that each image is averaged from
          images within an equal fraction of the decoding range
    save_path: str, optional, default None
        If given, save each average image to save_path + "_image<k>.png"
        and a gif of all of them to save_path + ".gif"

    Raises
    ------
    ValueError
        When an invalid value for "intervals" is given

    Warns
    -----
    UserWarning
        When some of the bins are empty; their parameter averages are NaN
        and are left out of the parameter plots
    """
    import warnings

    orientation, frequency, phase, contrast = _grating_parameters(decoding)
    decoding = decoding.to_numpy().ravel()
    N = decoding.shape[0]

    # Assign images to intervals (raises ValueError for unknown types)
    interval_labels, count = _assign_intervals(decoding, Nimages, intervals)

    # Average over grating images; every image contributes to its bin so
    # the sums match `count` (and the result is deterministic)
    av_images = np.zeros([Nimages, Npixels, Npixels])
    grouped_indices = [[] for _ in range(Nimages)]
    iterator = trange(0, N, position=0, leave=True)
    iterator.set_description("Averaging images")
    for j in iterator:
        pars = np.array([orientation[j], frequency[j], phase[j], contrast[j]])
        grouped_indices[interval_labels[j]].append(pars)
        av_images[interval_labels[j]] += grating_image(pars, N=Npixels,
                                                       plot=False)
    nonempty = count > 0
    av_images[nonempty] /= count[nonempty][:, None, None]

    print("Number of images averaged over:")
    print(count)

    if normalized:
        peak = np.max(np.abs(av_images))
        if peak > 0:  # avoid 0/0 when every average image is blank
            av_images = av_images / peak

    # Show averages and make gif
    for i in range(Nimages):
        if not math.isnan(av_images[i, 0, 0]):
            plt.imshow(av_images[i], "gray", vmin=-1, vmax=1)
            if save_path is not None:
                plt.savefig(save_path + "_image" + str(i + 1) + ".png")
            plt.show()
    if save_path is not None:
        # Map [-1, 1] to [0, 255]; clip so out-of-range values cannot wrap
        frames = (255 * 0.5 * (1 + np.clip(av_images, -1, 1))).astype(np.uint8)
        imageio.mimsave(save_path + ".gif", frames)

    # Plot the grating parameters against the decoding
    if any(len(group) == 0 for group in grouped_indices):
        warnings.warn("Some bins are empty; choose a smaller bin count.")
    parameters = [
        ("Orientation", orientation,
         lambda v: angular_mean(v, period=np.pi, axis=0)),
        ("Frequency", frequency, lambda v: np.mean(v, axis=0)),
        ("Phase", phase,
         lambda v: angular_mean(v, period=2 * np.pi, axis=0)),
        ("Contrast", contrast, lambda v: np.mean(v, axis=0)),
    ]
    # Column k of each stored `pars` row holds parameter k of `parameters`
    for column, (name, values, average) in enumerate(parameters):
        if len(set(values)) <= 1:
            continue  # constant parameter: nothing to show
        averages = [average(np.array(group)[:, column]) if group else np.nan
                    for group in grouped_indices]
        plt.scatter(np.arange(Nimages), averages)
        plt.xlabel("Decoding")
        plt.ylabel(name)
        plt.show()
    return


def _grating_parameters(decoding):
    """
    Extract (orientation, frequency, phase, contrast) from decoding's index

    Each is returned as a numpy array of length len(decoding); a parameter
    missing from the index is filled with ones. When the largest phase
    exceeds 2*pi the labels are assumed to be in degrees: orientation and
    phase are converted to radians and frequency is multiplied by 30
    (NOTE(review): unit convention of the stimulus set -- confirm).
    """
    table = decoding.reset_index()

    def column(name):
        try:
            return table[name].to_numpy()
        except KeyError:
            return np.ones(len(decoding))

    orientation = column("orientation")
    contrast = column("contrast")
    frequency = column("frequency")
    phase = column("phase")

    if max(phase) > 2 * np.pi:
        orientation = orientation * np.pi / 180
        frequency = frequency * 30
        phase = phase * np.pi / 180
    return orientation, frequency, phase, contrast


def _assign_intervals(decoding, Nimages, intervals):
    """
    Assign every data point to one of `Nimages` bins of the decoding

    Returns (labels, count): labels[j] in range(Nimages) is the bin of
    point j, and count[i] (float) is the number of points in bin i.
    Raises ValueError for an unknown `intervals` type.
    """
    N = decoding.shape[0]
    if intervals == "equal_decoding":
        # Equally wide bins over [min, max]; a value on an inner edge goes to
        # the upper bin (counted once), the maximum goes to the last bin
        edges = np.linspace(min(decoding), max(decoding), Nimages + 1)
        labels = np.searchsorted(edges, decoding, side="right") - 1
        labels = np.clip(labels, 0, Nimages - 1)
    elif intervals == "equal_images":
        # sorted_bins[k] is the bin of the k-th smallest decoding; scatter
        # it back to each point's position (the inverse of the sort order)
        sorted_bins = np.floor(np.linspace(0, Nimages - 0.01, N)).astype(int)
        labels = np.empty(N, dtype=int)
        labels[np.argsort(decoding)] = sorted_bins
    else:
        raise ValueError("Invalid intervals type")
    count = np.bincount(labels, minlength=Nimages).astype(float)
    return labels, count
|
|
|