Lev
2 years ago
5 changed files with 924 additions and 0 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -0,0 +1,78 @@
|
||||
import typing |
||||
|
||||
import matplotlib as mpl |
||||
import numpy as np |
||||
import matplotlib.pyplot as plt |
||||
|
||||
import sympy as sp |
||||
|
||||
from utils import eval_func, get_orientation_phase_grid, get_spatial_grid |
||||
|
||||
sp.init_printing() |
||||
AxOrImg = typing.Union[mpl.axes.Axes, mpl.image.AxesImage] |
||||
|
||||
|
||||
# %% |
||||
|
||||
|
||||
def plot_spatial(func: sp.Expr, ax: AxOrImg, step_x: float = 0.05, step_y: float = 0.05, size: float = 1, |
||||
title: str = None, show: bool = False, |
||||
patch: typing.Optional[typing.Tuple[float, float, float]] = None |
||||
): |
||||
""" |
||||
Plots a spatial map of the function. |
||||
|
||||
:param func: function to plot |
||||
:param ax: axes to plot on or the image on axes |
||||
:param step_x: step for the x-coordinate |
||||
:param step_y: step for the y-coordinate |
||||
:param size: size of the grid |
||||
:param title: title of the plot |
||||
:param show: whether to show the plot |
||||
:param patch: optional circle to plot - a tuple (x, y, radius) |
||||
""" |
||||
grid = get_spatial_grid(step_x, step_y, size) |
||||
image: np.ndarray = eval_func(func, x, y, grid) |
||||
if isinstance(ax, mpl.image.AxesImage): |
||||
ax.set_data(image) |
||||
return ax |
||||
img = ax.imshow(image, extent=[-size, size, -size, size], vmin=-size, vmax=size, cmap='gray') |
||||
ax.invert_yaxis() |
||||
if patch is not None: |
||||
ax.add_patch(plt.Circle(patch[:2], radius=patch[2], color='b', fill=False)) |
||||
ax.set_title(title) |
||||
if show: |
||||
plt.show() |
||||
return img |
||||
|
||||
|
||||
def normalize(img): |
||||
return (img - img.min()) / (img.max() - img.min()) |
||||
|
||||
|
||||
def plot_tuning_curve(func: typing.Union[sp.Expr, typing.Callable], ax: AxOrImg, step_phase: float = 20, |
||||
step_orientation: float = 15, title: str = None, show: bool = False): |
||||
""" |
||||
Plots a tuning curve of the function. |
||||
|
||||
:param func: function to plot - sympy or a function of (theta, phi) |
||||
:param ax: axes to plot on or image to update |
||||
:param step_phase: step for the phase (phi) - in degrees |
||||
:param step_orientation: step for the orientation (theta) - in degrees |
||||
:param title: title of the plot |
||||
:param show: whether to show the plot |
||||
""" |
||||
grid = get_orientation_phase_grid(step_phase, step_orientation) |
||||
if isinstance(func, sp.Expr): |
||||
image: np.ndarray = eval_func(func, theta, phi, grid) |
||||
else: |
||||
image = np.array([[func(theta_val, phi_val) for theta_val, phi_val in line] for line in grid]) |
||||
image = normalize(image) |
||||
if isinstance(ax, mpl.image.AxesImage): |
||||
ax.set_data(image) |
||||
return ax |
||||
img = ax.imshow(image, extent=[0, 360, 0, 180], cmap='viridis') |
||||
ax.set_title(title) |
||||
if show: |
||||
plt.show() |
||||
return img |
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations |
||||
import typing |
||||
from dataclasses import dataclass, field |
||||
|
||||
import matplotlib as mpl |
||||
import sympy as sp |
||||
import numpy as np |
||||
|
||||
from utils import get_orientation_phase_grid |
||||
|
||||
sp.init_printing() |
||||
k, x0, y0, phi, theta, sigma_x, sigma_y, sigma, x, y = sp.symbols(r'k x_0 y_0 \phi \theta \sigma_x \sigma_y \sigma x y') |
||||
defaults = { |
||||
k: 6, |
||||
sigma: 0.2, |
||||
phi: sp.pi / 2, |
||||
theta: 0, |
||||
sigma: 1, |
||||
x0: 0, y0: 0 |
||||
} |
||||
sigma_x = sigma_y = sigma |
||||
grating_f = sp.cos(k * (x - x0) * sp.cos(theta) + k * (y - y0) * sp.sin(theta) + phi) |
||||
receptive_field = 1 / (2 * sp.pi * sigma * sigma) * sp.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) * sp.cos( |
||||
k * x * sp.cos(theta) + k * y * sp.sin(theta) + phi) |
||||
receptive_field = receptive_field.subs(theta, 0).subs(phi, 0) |
||||
p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta)) * sp.exp(k ** 2 * (1 + sp.cos(theta) ** 2) / 2) * sp.cos( |
||||
phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta))) |
||||
|
||||
sigma_split = np.arange(0.1, 1, 0.05) |
||||
k_split = np.arange(0.2, 6, 0.2) |
||||
xy_split = np.arange(-1, 1, 0.05) |
||||
|
||||
|
||||
@dataclass |
||||
class Cell: |
||||
sigma_val: float = defaults[sigma] |
||||
x0_val: float = defaults[x0] |
||||
y0_val: float = defaults[y0] |
||||
k_val: float = defaults[k] |
||||
|
||||
@classmethod |
||||
def random(cls, sigma_dist: np.ndarray = np.ones(len(sigma_split)), |
||||
k_val: float = defaults[k], |
||||
xy_dist: np.ndarray = np.ones(len(xy_split))): |
||||
return cls( |
||||
sigma_val=np.random.choice(sigma_split, p=sigma_dist / np.sum(sigma_dist)), |
||||
x0_val=np.random.choice(xy_split, p=xy_dist / np.sum(xy_dist)), |
||||
y0_val=np.random.choice(xy_split, p=xy_dist / np.sum(xy_dist)), |
||||
k_val=k_val |
||||
) |
||||
|
||||
@property |
||||
def sympy_func(self) -> sp.Expr: |
||||
return receptive_field.subs(sigma, self.sigma_val).subs(x0, self.x0_val).subs(y0, self.y0_val).subs(k, |
||||
self.k_val) |
||||
|
||||
def get_tuning_function(self) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]: |
||||
""" |
||||
Get the tuning sympy function as a numpy lambda function of theta and phi. |
||||
:return: a function (theta, phi) -> value |
||||
""" |
||||
return sp.lambdify( |
||||
(theta, phi), |
||||
p.subs(sigma, self.sigma_val).subs(x0, self.x0_val).subs(y0, self.y0_val).subs(k, self.k_val), |
||||
'numpy') |
||||
|
||||
def get_value(self, theta_deg: float, phi_deg: float) -> float: |
||||
return float(self.get_tuning_function()(theta_deg * np.pi / 180, phi_deg * np.pi / 180)) |
||||
|
||||
def get_tuning_plot(self, theta_step_deg: float, phi_step_deg: float) -> np.ndarray: |
||||
grid = get_orientation_phase_grid(theta_step_deg, phi_step_deg) |
||||
return self.get_tuning_function()(grid[:, :, 0], grid[:, :, 1]) |
||||
|
||||
|
||||
@dataclass |
||||
class Grating: |
||||
k_val: float = defaults[k] |
||||
phi_val: float = defaults[phi] |
||||
theta_val: float = defaults[theta] |
||||
|
||||
@property |
||||
def sympy_func(self) -> sp.Expr: |
||||
return grating_f.subs(k, self.k_val).subs(phi, self.phi_val).subs(theta, self.theta_val) |
||||
|
||||
|
||||
@dataclass |
||||
class Population: |
||||
cells: typing.List[Cell] = field(default_factory=list) |
||||
|
||||
@classmethod |
||||
def random(cls, n: int, sigma_dist: np.ndarray = np.ones(len(sigma_split)), |
||||
k_val: float = defaults[k], |
||||
xy_dist: np.ndarray = np.ones(len(xy_split))): |
||||
return cls(cells=[Cell.random(sigma_dist, k_val, xy_dist) for _ in range(n)]) |
||||
|
||||
def get_response(self, phi_deg: float, theta_deg: float) -> typing.List[float]: |
||||
return [cell.get_value(theta_deg, phi_deg) for cell in self.cells] |
||||
|
||||
def sample_responses(self, n: int) -> np.ndarray: |
||||
return np.array([ |
||||
self.get_response(phi_deg, theta_deg % 180) |
||||
for phi_deg, theta_deg in np.random.uniform(0, 360, (n, 2)) |
||||
]) |
@ -0,0 +1,37 @@
|
||||
import numpy as np |
||||
import sympy as sp |
||||
|
||||
|
||||
def eval_func(func: sp.Expr, sub_1: sp.Expr, sub_2: sp.Expr, grid: np.ndarray) -> np.ndarray: |
||||
# return np.array([[float(func.subs(sub_1, x_).subs(sub_2, y_)) for x_, y_ in line] for line in grid]) |
||||
func = sp.lambdify([sub_1, sub_2], func, 'numpy') |
||||
return func(grid[:, :, 0], grid[:, :, 1]) |
||||
|
||||
|
||||
def get_orientation_phase_grid(step_phase: float, step_orientation: float) -> np.ndarray: |
||||
""" |
||||
Returns a grid of x and y values for plotting. |
||||
:param step_phase: step for the phase (phi) - in degrees |
||||
:param step_orientation: step for the orientation (theta) - in degrees |
||||
:return: numpy array of shape (n_orientation, n_phase). Each element is a tuple (theta, phi) |
||||
""" |
||||
# phase <-> phi |
||||
# orientation <-> theta |
||||
step_phase *= np.pi / 180 |
||||
step_orientation *= np.pi / 180 |
||||
phi = np.arange(0, 2 * np.pi, step_phase) |
||||
theta = np.arange(0, np.pi, step_orientation) |
||||
return np.array(np.meshgrid(theta, phi)).T.reshape(-1, len(phi), 2) |
||||
|
||||
|
||||
def get_spatial_grid(step_x: float, step_y: float, size: float = 1) -> np.ndarray: |
||||
""" |
||||
Returns a grid of x and y values for plotting. |
||||
:param step_x: step for the x-coordinate |
||||
:param step_y: step for the y-coordinate |
||||
:param size: size of the grid |
||||
:return: numpy array of shape (2 * size / step_x, 2 * size / step_y). Each element is a tuple (x, y) |
||||
""" |
||||
x = np.arange(-size, size, step_x) |
||||
y = np.arange(-size, size, step_y) |
||||
return np.array(np.meshgrid(x, y)).T.reshape(-1, len(x), 2) |
Loading…
Reference in new issue