import typing

import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp

from utils import eval_func, get_orientation_phase_grid, get_spatial_grid

sp.init_printing()

# Either an Axes to draw a new image on, or an existing AxesImage whose data
# is replaced in place.
AxOrImg = typing.Union[mpl.axes.Axes, mpl.image.AxesImage]


# %%
def plot_spatial(func: sp.Expr,
                 ax: AxOrImg,
                 step_x: float = 0.05,
                 step_y: float = 0.05,
                 size: float = 1,
                 title: typing.Optional[str] = None,
                 show: bool = False,
                 patch: typing.Optional[typing.Tuple[float, float, float]] = None,
                 ) -> mpl.image.AxesImage:
    """
    Plots a spatial map of the function.

    :param func: function to plot
    :param ax: axes to plot on or the image on axes; if an AxesImage is given,
        only its data is replaced (title, patch and show are ignored)
    :param step_x: step for the x-coordinate
    :param step_y: step for the y-coordinate
    :param size: size of the grid (half-width; the image spans [-size, size])
    :param title: title of the plot
    :param show: whether to show the plot
    :param patch: optional circle to plot - a tuple (x, y, radius)
    :return: the AxesImage that was drawn or updated
    """
    grid = get_spatial_grid(step_x, step_y, size)
    # NOTE(review): `x` and `y` are looked up as module-level globals at call
    # time and are not defined in this view — confirm they are the sympy
    # Symbols used in `func`, defined elsewhere in the module.
    image: np.ndarray = eval_func(func, x, y, grid)

    if isinstance(ax, mpl.image.AxesImage):
        # Fast path: refresh an existing image without touching the axes.
        ax.set_data(image)
        return ax

    # NOTE(review): the colour limits reuse the spatial half-width `size`;
    # confirm that this matches the expected value range of `func`.
    img = ax.imshow(image, extent=[-size, size, -size, size],
                    vmin=-size, vmax=size, cmap='gray')
    ax.invert_yaxis()
    if patch is not None:
        ax.add_patch(plt.Circle(patch[:2], radius=patch[2],
                                color='b', fill=False))
    ax.set_title(title)
    if show:
        plt.show()
    return img


def normalize(img):
    """
    Linearly rescales an array to the range [0, 1].

    A constant array (zero range) is mapped to all zeros instead of the
    all-NaN result that a 0/0 division would produce, and an empty array is
    returned unchanged.

    :param img: array-like of numeric values
    :return: float array of the same shape with values in [0, 1]
    """
    img = np.asarray(img, dtype=float)
    if img.size == 0:
        return img
    low = img.min()
    value_range = img.max() - low
    if value_range == 0:
        return np.zeros_like(img)
    return (img - low) / value_range


def plot_tuning_curve(func: typing.Union[sp.Expr, typing.Callable],
                      ax: AxOrImg,
                      step_phase: float = 20,
                      step_orientation: float = 15,
                      title: typing.Optional[str] = None,
                      show: bool = False,
                      ) -> mpl.image.AxesImage:
    """
    Plots a tuning curve of the function.

    The evaluated response is rescaled to [0, 1] with `normalize` before
    plotting.

    :param func: function to plot - sympy or a function of (theta, phi)
    :param ax: axes to plot on or image to update; if an AxesImage is given,
        only its data is replaced (title and show are ignored)
    :param step_phase: step for the phase (phi) - in degrees
    :param step_orientation: step for the orientation (theta) - in degrees
    :param title: title of the plot
    :param show: whether to show the plot
    :return: the AxesImage that was drawn or updated
    """
    grid = get_orientation_phase_grid(step_phase, step_orientation)
    if isinstance(func, sp.Expr):
        # NOTE(review): `theta` and `phi` are looked up as module-level
        # globals at call time and are not defined in this view — confirm
        # they are the sympy Symbols used in `func`.
        image: np.ndarray = eval_func(func, theta, phi, grid)
    else:
        # Assumes each grid row is an iterable of (theta, phi) pairs.
        image = np.array([[func(theta_val, phi_val) for theta_val, phi_val in line]
                          for line in grid])
    image = normalize(image)

    if isinstance(ax, mpl.image.AxesImage):
        ax.set_data(image)
        return ax

    img = ax.imshow(image, extent=[0, 360, 0, 180], cmap='viridis')
    ax.set_title(title)
    if show:
        plt.show()
    return img