Topology in neuroscience
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
2.6 KiB

import typing
import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
from utils import eval_func, get_orientation_phase_grid, get_spatial_grid
# Either a Matplotlib Axes or an AxesImage; used to annotate the `ax` parameter of the plotting helpers.
AxOrImg = typing.Union[mpl.axes.Axes, mpl.image.AxesImage]
# %%
def plot_spatial(func: sp.Expr, ax: AxOrImg, step_x: float = 0.05, step_y: float = 0.05, size: float = 1,
                 title: typing.Optional[str] = None, show: bool = False,
                 patch: typing.Optional[typing.Tuple[float, float, float]] = None
                 ) -> mpl.image.AxesImage:
    """
    Plots a spatial map of the function.

    When ``ax`` is an existing image, only its pixel data is replaced and the
    same image is returned; ``title``, ``patch`` and ``show`` are ignored on
    that path.

    :param func: function to plot
    :param ax: axes to plot on or the image on axes
    :param step_x: step for the x-coordinate
    :param step_y: step for the y-coordinate
    :param size: size of the grid; the image extent is [-size, size] on both axes
    :param title: title of the plot
    :param show: whether to show the plot
    :param patch: optional circle to plot - a tuple (x, y, radius)
    :return: the image that was created or updated
    """
    grid = get_spatial_grid(step_x, step_y, size)
    # NOTE(review): `x` and `y` are not defined in this function; presumably
    # they are module-level sympy symbols - confirm they are in scope.
    image: np.ndarray = eval_func(func, x, y, grid)

    if isinstance(ax, mpl.image.AxesImage):
        # Refresh the existing image in place instead of drawing a new one.
        ax.set_data(image)
        return ax

    # NOTE(review): the colour limits are tied to `size`, not to the value
    # range of `func` - kept as in the original; confirm this is intended.
    img = ax.imshow(image, extent=[-size, size, -size, size], vmin=-size, vmax=size, cmap='gray')
    if patch is not None:
        ax.add_patch(plt.Circle(patch[:2], radius=patch[2], color='b', fill=False))
    if title is not None:
        ax.set_title(title)
    if show:
        plt.show()
    return img
def normalize(img):
    """
    Rescales an array linearly so that its minimum maps to 0 and its maximum to 1.

    A constant array (maximum equal to minimum) is mapped to all zeros instead
    of dividing by zero.

    :param img: array-like object providing ``min()`` and ``max()`` (e.g. a numpy array)
    :return: the rescaled array
    """
    lowest = img.min()
    shifted = img - lowest
    span = img.max() - lowest
    if span == 0:
        # Constant image: `shifted` is all zeros, so dividing by 1 yields
        # zeros rather than NaN from a 0 / 0 division.
        span = 1
    return shifted / span
def plot_tuning_curve(func: typing.Union[sp.Expr, typing.Callable], ax: AxOrImg, step_phase: float = 20,
                      step_orientation: float = 15, title: typing.Optional[str] = None,
                      show: bool = False) -> mpl.image.AxesImage:
    """
    Plots a tuning curve of the function.

    The evaluated values are min-max normalized before plotting. When ``ax``
    is an existing image, only its pixel data is replaced and the same image
    is returned; ``title`` and ``show`` are ignored on that path.

    :param func: function to plot - sympy or a function of (theta, phi)
    :param ax: axes to plot on or image to update
    :param step_phase: step for the phase (phi) - in degrees
    :param step_orientation: step for the orientation (theta) - in degrees
    :param title: title of the plot
    :param show: whether to show the plot
    :return: the image that was created or updated
    """
    grid = get_orientation_phase_grid(step_phase, step_orientation)
    if isinstance(func, sp.Expr):
        # NOTE(review): `theta` and `phi` are not defined in this function;
        # presumably they are module-level sympy symbols - confirm.
        image: np.ndarray = eval_func(func, theta, phi, grid)
    else:
        # Plain callables are evaluated point by point; each grid entry is
        # unpacked as a (theta, phi) pair.
        image = np.array([[func(theta_val, phi_val) for theta_val, phi_val in line] for line in grid])
    image = normalize(image)

    if isinstance(ax, mpl.image.AxesImage):
        # Refresh the existing image in place instead of drawing a new one.
        ax.set_data(image)
        return ax

    img = ax.imshow(image, extent=[0, 360, 0, 180], cmap='viridis')
    if title is not None:
        ax.set_title(title)
    if show:
        plt.show()
    return img