# -*- coding: utf-8 -*-
"""
Some plotting functions related to grating stimuli responses
"""

import numpy as np
import pandas as pd
import math
import random
import plotly.graph_objects as go
from plotly.offline import plot
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import imageio
import sklearn.manifold as manifold
from sklearn.decomposition import PCA
from tqdm import trange

from decorators import multi_input
from gratings import angular_mean, grating_image


@multi_input
def plot_data(data, transformation="None", labels=None, colors=None,
              save_path=None):
    """
    Plot data colored by its indices and additional provided labels

    One 3D scatter subplot is drawn per label (every index level of `data`
    plus every column of `labels`), laid out on two rows.

    Parameters
    ----------
    data: dataframe(n_datapoints, n_features):
        Dataframe containing the data. The first three columns (after the
        optional transformation) are used as x, y and z coordinates.
    transformation: str, optional, default "None"
        The type of dimension reduction used
        Choose from "None", "PCA" or "SpectralEmbedding"
    labels: dataframe(n_datapoints, n_labels), optional, default None
        Dataframe containing additional labels to be plotted
    colors: list of str, optional, default None
        A list containing the color scales used for each label
        When None use "Viridis" for all labels
    save_path: str, optional, default None
        When given save the figure as html to
        save_path + "plot[_<transformation>].html"

    Raises
    ------
    ValueError
        When an invalid value for "transformation" is given
    """
    # Every index level becomes a label, followed by the user labels
    indices = data.index
    plotted_labels = indices.to_frame()
    if labels is not None:
        plotted_labels = plotted_labels.join(labels)
    if colors is None:
        colors = ["Viridis"] * len(plotted_labels.columns)
    Nlabels = len(plotted_labels.columns)

    # Transform data to lower dimension
    if transformation == "PCA":
        pca = PCA(n_components=3)
        pca.fit(data)
        data = pca.transform(data)
    elif transformation == "SpectralEmbedding":
        embedding = manifold.SpectralEmbedding(n_components=3, affinity='rbf')
        data = embedding.fit_transform(data)
    elif transformation == "None":
        pass
    else:
        raise ValueError("Invalid plot transformation")

    # Select the coordinates by position: with transformation="None" the
    # columns keep their original labels, so label lookups data[0], data[1],
    # data[2] would fail unless the columns happen to be named 0, 1, 2.
    coords = pd.DataFrame(data)
    coords.index = indices
    x = coords.iloc[:, 0]
    y = coords.iloc[:, 1]
    z = coords.iloc[:, 2]

    ncols = math.ceil(Nlabels / 2)
    fig = make_subplots(
        rows=2, cols=ncols,
        specs=[[{'type': 'scene'} for _ in range(ncols)] for _ in range(2)],
    )
    for i, label in enumerate(plotted_labels):
        fig.add_trace(
            go.Scatter3d(
                mode='markers',
                name=label,
                x=x, y=y, z=z,
                text=plotted_labels[label],
                # NOTE(review): a list here is read by plotly as per-point
                # values; 'x+text' may have been intended — confirm.
                hoverinfo=['x', 'text'],
                marker=dict(
                    color=plotted_labels[label],
                    size=5,
                    sizemode='diameter',
                    colorscale=colors[i],
                ),
            ),
            # Fill the grid column by column: labels 0,1 in column 1, etc.
            row=i % 2 + 1, col=i // 2 + 1,
        )
    fig.update_layout(height=900, width=1600, title_text="")

    if save_path is None:
        plot(fig)
        fig.show()
    else:
        path = save_path + 'plot'
        if transformation != 'None':
            path += '_' + transformation
        path += ".html"
        fig.write_html(path)
    return


def plot_connections(data_points, connections, threshold=0.1, opacity=0.1,
                     save_path=None):
    """
    Plot points in 3D and draw a line for every connection above a threshold

    Parameters
    ----------
    data_points: array(n_points, >=3)
        The first three columns are used as x, y and z coordinates
    connections: matrix(n_points, n_points)
        Connection strengths; a line from point i to point j is drawn when
        connections[i, j] > threshold (both directions and the diagonal are
        checked, so symmetric matrices draw each line twice)
    threshold: float, optional, default 0.1
        Strict lower bound on the connection strength
    opacity: float, optional, default 0.1
        Opacity of the connection lines
    save_path: str, optional, default None
        When given save the figure as save_path + "/plot_glue.html",
        otherwise open it with plotly's offline plot
    """
    x = data_points[:, 0]
    y = data_points[:, 1]
    z = data_points[:, 2]
    N = len(x)

    # Sparse matrices support [i, j] indexing but not np.asarray; densify
    if hasattr(connections, "toarray"):
        connections = connections.toarray()
    # Row-major (i, j) pairs of the N x N block above threshold: the same
    # pairs, in the same order, as scanning np.ndindex((N, N)) in Python.
    pairs = np.argwhere(np.asarray(connections)[:N, :N] > threshold)

    trace1 = go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        name='markers'
    )

    # One segment per pair; None breaks the polyline between segments
    x_lines, y_lines, z_lines = [], [], []
    for i, j in pairs:
        x_lines += [x[i], x[j], None]
        y_lines += [y[i], y[j], None]
        z_lines += [z[i], z[j], None]

    trace2 = go.Scatter3d(
        x=x_lines, y=y_lines, z=z_lines,
        opacity=opacity,
        mode='lines',
        name='lines'
    )

    fig = go.Figure(data=[trace1, trace2])
    if save_path is not None:
        fig.write_html(save_path + '/plot_glue.html')
    else:
        plot(fig)
    return


@multi_input
def plot_mean_against_index(data, value, index, circ=True, save_path=None):
    """
    Plot the mean value against an index

    Nothing is plotted when the index takes a single value.

    Parameters
    ----------
    data: dataframe(n_datapoints, n_features):
        Dataframe containing the data
    value : dataframe(n_datapoints)
        The value we average over
    index : str
        The name of the index we plot against
    circ : bool, optional, default True
        Whether or not the value is angular (period 1); if so the angular
        mean is used and the y axis is fixed to [0, 1]
    save_path : str, optional, default None
        Not used in this function body (NOTE(review): the multi_input
        decorator may consume it — confirm); the figure is only shown
    """
    unique_index = data.reset_index()[index].unique()
    if len(unique_index) <= 1:
        return

    grouped = value.groupby([index])
    if circ:
        means = grouped.agg(lambda X: angular_mean(X, period=1))
    else:
        means = grouped.mean()

    plt.scatter(means.index, means)
    if circ:
        plt.ylim(0, 1)
    plt.xlabel(index)
    plt.ylabel('mean')
    plt.show()
    return


def _grating_parameters(decoding):
    """Return orientation, frequency, phase, contrast arrays from the index.

    Labels absent from the index of `decoding` default to ones. When some
    phase exceeds 2*pi the labels are taken to be in degrees: orientation and
    phase are converted to radians and frequency is multiplied by 30
    (NOTE(review): meaning of the factor 30 is not documented here).
    """
    table = decoding.reset_index()
    n = len(decoding)

    def column(name):
        try:
            return table[name].to_numpy()
        except KeyError:
            return np.ones(n)

    orientation = column("orientation")
    frequency = column("frequency")
    phase = column("phase")
    contrast = column("contrast")

    if np.max(phase) > 2 * np.pi:
        orientation = orientation * np.pi / 180
        frequency = frequency * 30
        phase = phase * np.pi / 180
    return orientation, frequency, phase, contrast


def _assign_intervals(decoding, Nimages, intervals):
    """Return the bin label (0..Nimages-1) of every entry of 1-D `decoding`.

    Raises ValueError for an unknown `intervals` mode.
    """
    N = decoding.shape[0]
    if intervals == "equal_decoding":
        # Equal-width bins over [min, max]. Bins are half-open [e_k, e_k+1)
        # so each point lands in exactly one bin; the maximum goes to the
        # last bin.
        edges = np.linspace(np.min(decoding), np.max(decoding), Nimages + 1)
        labels = np.searchsorted(edges, decoding, side="right") - 1
        return np.clip(labels, 0, Nimages - 1)
    if intervals == "equal_images":
        # Labels for the points in increasing decoding order, spread evenly
        # over 0..Nimages-1 so every bin holds (nearly) the same count.
        ranked = np.floor(np.linspace(0, Nimages - 0.01, N)).astype(int)
        # Scatter them back: the point with the k-th smallest decoding gets
        # ranked[k].
        labels = np.empty(N, dtype=int)
        labels[np.argsort(decoding)] = ranked
        return labels
    raise ValueError("Invalid intervals type")


def _bin_mean(bin_pars, column, period):
    """Mean of one parameter column over a bin (angular if `period` given).

    Raises IndexError when the bin is empty.
    """
    values = np.array(bin_pars)[:, column]
    if period is None:
        return np.mean(values, axis=0)
    return angular_mean(values, period=period, axis=0)


@multi_input
def show_feature(decoding, Nimages=10, Npixels=100, normalized=True,
                 intervals="equal_images", save_path=None):
    """
    Show how the gratings depend on a decoded parameter

    Shows the average grating for different values of the decoding and plot
    the grating parameters as a function of the decoding

    Parameters
    ----------
    decoding : dataframe(n_datapoints)
        A dataframe containing the decoded value for each data point labeled
        by indices "orientation", "frequency", "phase" and "contrast"
    Nimages : int, optional, default 10
        Number of different images shown
    Npixels: int, optional, default 100
        The number of pixels in each direction
    normalized : bool, optional, default True
        If true normalize the average images by their largest absolute value
    intervals : str, optional, default "equal_images"
        How the images are binned together
        - When "equal_images" average such that each image is an average of
        an equal number of images
        - When "equal_decoding" average such that each image is averaged
        from images within an equal fraction of the decoding
    save_path: str, optional, default None
        If given, save every image as save_path + "_image<k>.png" and a gif
        as save_path + ".gif"

    Raises
    ------
    ValueError
        When an invalid value for "intervals" is given

    Notes
    -----
    To save time only a random ~50% of the images in each bin are rendered;
    each average is taken over the images actually rendered. A message is
    printed (and the parameter plots are skipped) when some bins are empty.
    """
    orientation, frequency, phase, contrast = _grating_parameters(decoding)
    decoding = decoding.to_numpy().ravel()
    N = decoding.shape[0]

    interval_labels = _assign_intervals(decoding, Nimages, intervals)

    # Average over grating images
    av_images = np.zeros([Nimages, Npixels, Npixels])
    n_summed = np.zeros(Nimages)
    grouped_indices = [[] for _ in range(Nimages)]
    iterator = trange(0, N, position=0, leave=True)
    iterator.set_description("Averaging images")
    for j in iterator:
        label = interval_labels[j]
        pars = np.array([orientation[j], frequency[j], phase[j], contrast[j]])
        grouped_indices[label].append(pars)
        # Random subsampling to save rendering time; track how many images
        # were actually summed so the average uses the right denominator.
        if random.choice([True, False]):
            av_images[label] += grating_image(pars, N=Npixels, plot=False)
            n_summed[label] += 1
    nonempty = n_summed > 0
    av_images[nonempty] /= n_summed[nonempty][:, None, None]
    print("Number of images averaged over:")
    print(n_summed)

    if normalized:
        peak = np.max(np.abs(av_images))
        if peak > 0:  # all-zero averages would otherwise become NaN
            av_images = av_images / peak

    # Show averages and make gif
    for i in range(Nimages):
        if not math.isnan(av_images[i, 0, 0]):
            plt.imshow(av_images[i], "gray", vmin=-1, vmax=1)
            if save_path is not None:
                plt.savefig(save_path + "_image" + str(i + 1) + ".png")
            plt.show()
    if save_path is not None:
        # Map [-1, 1] to [0, 255]; clip so unnormalized images do not wrap
        frames = np.clip(255 * 0.5 * (1 + av_images), 0, 255).astype(np.uint8)
        imageio.mimsave(save_path + ".gif", frames)

    # Plot the grating parameters against the decoding:
    # (axis label, column of pars, angular period or None, raw labels)
    specs = [
        ("Orientation", 0, np.pi, orientation),
        ("Frequency", 1, None, frequency),
        ("Phase", 2, 2 * np.pi, phase),
        ("Contrast", 3, None, contrast),
    ]
    try:
        averages = [
            np.array([_bin_mean(b, column, period) for b in grouped_indices])
            for _, column, period, _ in specs
        ]
    except IndexError:
        print("Error: some bins are empty; choose a smaller bin count.")
    else:
        for (name, _, _, raw), average in zip(specs, averages):
            # Skip parameters that never vary in this data set
            if len(set(raw)) > 1:
                plt.scatter(np.arange(Nimages), average)
                plt.xlabel("Decoding")
                plt.ylabel(name)
                plt.show()
    return