You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
3.8 KiB
104 lines
3.8 KiB
2 years ago
|
from __future__ import annotations

import functools
import typing
from dataclasses import dataclass, field

import matplotlib as mpl
import numpy as np
import sympy as sp

from utils import get_orientation_phase_grid
|
||
|
|
||
|
sp.init_printing()

k, x0, y0, phi, theta, sigma_x, sigma_y, sigma, x, y = sp.symbols(r'k x_0 y_0 \phi \theta \sigma_x \sigma_y \sigma x y')

# Default parameter values, keyed by sympy symbol.
# NOTE(review): sigma used to be listed twice in this dict (0.2, then 1). A dict literal keeps
# the last duplicate, so the effective default has been 1; only that entry is kept here.
# Confirm 0.2 was not the intended default.
defaults = {
    k: 6,
    sigma: 1,
    phi: sp.pi / 2,
    theta: 0,
    x0: 0,
    y0: 0,
}

# The envelope is isotropic: sigma_x and sigma_y are rebound to sigma, and every expression
# below is written in terms of sigma only.
sigma_x = sigma_y = sigma

# Sinusoidal grating with spatial frequency k, orientation theta and phase phi, whose phase
# origin is shifted to (x0, y0).
grating_f = sp.cos(k * (x - x0) * sp.cos(theta) + k * (y - y0) * sp.sin(theta) + phi)

# Gabor receptive field: an isotropic Gaussian envelope normalised to unit integral
# (factor 1 / (2 pi sigma^2)) times a sinusoidal carrier. It is then fixed to theta = phi = 0,
# i.e. a cos(k x) carrier centred on the origin.
receptive_field = 1 / (2 * sp.pi * sigma * sigma) * sp.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) * sp.cos(
    k * x * sp.cos(theta) + k * y * sp.sin(theta) + phi)
receptive_field = receptive_field.subs(theta, 0).subs(phi, 0)

# Tuning function: closed form of the overlap integral over the whole plane
#     p = integral of receptive_field(x, y) * grating_f(x, y) dx dy.
# Write cos(k x) * grating_f as half the sum of two cosines (product-to-sum), both with phase
#     psi = phi - k (x0 cos theta + y0 sin theta),
# and wave vectors u with |u|^2 = 2 k^2 (1 - cos theta) and 2 k^2 (1 + cos theta).
# For the unit-mass Gaussian G_sigma,
#     integral of G_sigma(r) * cos(u . r + psi) dr = exp(-sigma^2 |u|^2 / 2) * cos(psi),
# so
#     p = exp(-k^2 sigma^2) * cosh(k^2 sigma^2 cos theta) * cos(psi).
# The envelope is written as the equivalent average of two decaying exponentials, which
# cannot overflow for large k * sigma (a bare cosh would).
p = (sp.exp(-k ** 2 * sigma ** 2 * (1 - sp.cos(theta)))
     + sp.exp(-k ** 2 * sigma ** 2 * (1 + sp.cos(theta)))) / 2 * sp.cos(
    phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta)))

# Discrete parameter grids. Cell.random samples sigma from sigma_split and the position
# coordinates from xy_split; k_split is not used in this part of the file.
sigma_split = np.arange(0.1, 1, 0.05)
k_split = np.arange(0.2, 6, 0.2)
xy_split = np.arange(-1, 1, 0.05)
|
||
|
|
||
|
|
||
|
@functools.lru_cache(maxsize=4096, typed=True)
def _lambdify_tuning(sigma_val: float, x0_val: float, y0_val: float,
                     k_val: float) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]:
    """Substitute one cell's parameters into ``p`` and compile it to a numpy function.

    Symbolic substitution plus ``lambdify`` code generation is expensive, so the compiled
    function is memoised per parameter tuple (bounded cache). Cells with identical parameters
    share one function. Cell instances themselves hold no cache, so they stay picklable and
    changing a field can never return a stale function.

    :return: a function (theta, phi) -> value, with both angles in radians
    """
    expr = p.subs(sigma, sigma_val).subs(x0, x0_val).subs(y0, y0_val).subs(k, k_val)
    return sp.lambdify((theta, phi), expr, 'numpy')


@dataclass
class Cell:
    """A model cell: a Gabor receptive field plus its closed-form grating tuning ``p``."""

    # Width of the Gaussian envelope (value substituted for ``sigma``).
    sigma_val: float = defaults[sigma]
    # Position (values substituted for ``x0`` / ``y0`` in the tuning expression ``p``).
    x0_val: float = defaults[x0]
    y0_val: float = defaults[y0]
    # Spatial frequency (value substituted for ``k``).
    k_val: float = defaults[k]

    @classmethod
    def random(cls, sigma_dist: np.ndarray | None = None,
               k_val: float = defaults[k],
               xy_dist: np.ndarray | None = None) -> Cell:
        """Create a cell with a randomly drawn envelope width and position.

        ``sigma_val`` is drawn from ``sigma_split``. ``x0_val`` and ``y0_val`` are drawn
        independently from ``xy_split``. Each draw uses probabilities proportional to the
        given weights, via the global ``np.random`` state.

        :param sigma_dist: unnormalised weights over ``sigma_split``; uniform if None
        :param k_val: spatial frequency assigned to the cell (not randomised)
        :param xy_dist: unnormalised weights over ``xy_split``, shared by x and y; uniform if None
        :return: the new cell
        """
        # None instead of an array default: a default array would be one object shared by
        # every call.
        if sigma_dist is None:
            sigma_dist = np.ones(len(sigma_split))
        if xy_dist is None:
            xy_dist = np.ones(len(xy_split))
        sigma_dist = np.asarray(sigma_dist, dtype=float)
        xy_dist = np.asarray(xy_dist, dtype=float)

        sigma_probs = sigma_dist / np.sum(sigma_dist)
        xy_probs = xy_dist / np.sum(xy_dist)
        # Draw order (sigma, x0, y0) is significant for reproducibility under np.random.seed.
        return cls(
            sigma_val=np.random.choice(sigma_split, p=sigma_probs),
            x0_val=np.random.choice(xy_split, p=xy_probs),
            y0_val=np.random.choice(xy_split, p=xy_probs),
            k_val=k_val,
        )

    @property
    def sympy_func(self) -> sp.Expr:
        """The cell's receptive field as a sympy expression in ``x`` and ``y``.

        NOTE(review): ``receptive_field`` (theta = phi = 0) contains no ``x0``/``y0``, so those
        substitutions are no-ops. The returned field is centred on the origin whatever
        ``x0_val``/``y0_val`` are; the cell's position only enters through ``p``.
        """
        return (receptive_field
                .subs(sigma, self.sigma_val)
                .subs(x0, self.x0_val)
                .subs(y0, self.y0_val)
                .subs(k, self.k_val))

    def get_tuning_function(self) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]:
        """
        Get the tuning sympy function as a numpy lambda function of theta and phi.

        The compiled function is cached per parameter set, so repeated calls are cheap.

        :return: a function (theta, phi) -> value, with both angles in radians
        """
        return _lambdify_tuning(self.sigma_val, self.x0_val, self.y0_val, self.k_val)

    def get_value(self, theta_deg: float, phi_deg: float) -> float:
        """Evaluate the tuning at orientation ``theta_deg`` and phase ``phi_deg``, both in degrees.

        :return: the tuning value as a Python float
        """
        return float(self.get_tuning_function()(theta_deg * np.pi / 180, phi_deg * np.pi / 180))

    def get_tuning_plot(self, theta_step_deg: float, phi_step_deg: float) -> np.ndarray:
        """Evaluate the tuning over the grid from ``get_orientation_phase_grid``.

        ``grid[:, :, 0]`` is passed as theta and ``grid[:, :, 1]`` as phi.

        NOTE(review): the tuning function takes radians; confirm that
        ``get_orientation_phase_grid`` returns radians rather than degrees.

        :return: tuning values, presumably shaped like ``grid[:, :, 0]``
        """
        grid = get_orientation_phase_grid(theta_step_deg, phi_step_deg)
        return self.get_tuning_function()(grid[:, :, 0], grid[:, :, 1])
|
||
|
|
||
|
|
||
|
@dataclass
class Grating:
    """A sinusoidal grating stimulus with fixed frequency, phase and orientation."""

    # Spatial frequency (value substituted for ``k``).
    k_val: float = defaults[k]
    # Phase (value substituted for ``phi``).
    phi_val: float = defaults[phi]
    # Orientation (value substituted for ``theta``).
    theta_val: float = defaults[theta]

    @property
    def sympy_func(self) -> sp.Expr:
        """The grating as a sympy expression in ``x`` and ``y``.

        Only k, phi and theta are substituted; ``x0`` and ``y0`` remain free symbols.
        """
        expression = grating_f
        # Substitute in the fixed order k, phi, theta.
        for symbol, value in ((k, self.k_val), (phi, self.phi_val), (theta, self.theta_val)):
            expression = expression.subs(symbol, value)
        return expression
|
||
|
|
||
|
|
||
|
@dataclass
class Population:
    """A collection of model cells whose responses to a stimulus are read out together."""

    cells: typing.List[Cell] = field(default_factory=list)

    @classmethod
    def random(cls, n: int, sigma_dist: np.ndarray | None = None,
               k_val: float = defaults[k],
               xy_dist: np.ndarray | None = None) -> Population:
        """Create ``n`` cells, each drawn independently with ``Cell.random``.

        :param n: number of cells
        :param sigma_dist: unnormalised weights over ``sigma_split``; uniform if None
        :param k_val: spatial frequency given to every cell
        :param xy_dist: unnormalised weights over ``xy_split``; uniform if None
        :return: the new population
        """
        # None instead of an array default: a default array would be one object shared by
        # every call.
        if sigma_dist is None:
            sigma_dist = np.ones(len(sigma_split))
        if xy_dist is None:
            xy_dist = np.ones(len(xy_split))
        return cls(cells=[Cell.random(sigma_dist, k_val, xy_dist) for _ in range(n)])

    def get_response(self, phi_deg: float, theta_deg: float) -> typing.List[float]:
        """Return every cell's response to a grating with phase ``phi_deg`` and orientation ``theta_deg``.

        Angles are in degrees. The argument order is (phi, theta), the reverse of
        ``Cell.get_value``.
        """
        return [cell.get_value(theta_deg, phi_deg) for cell in self.cells]

    def sample_responses(self, n: int) -> np.ndarray:
        """Draw ``n`` random stimuli and return the population response to each.

        Phase and orientation are drawn uniformly from [0, 360) degrees, and the orientation
        is folded into [0, 180).

        :return: array of shape ``(n, len(self.cells))``, one row per stimulus
        """
        stimuli = np.random.uniform(0, 360, (n, 2))
        rows = [self.get_response(phi_deg, theta_deg % 180) for phi_deg, theta_deg in stimuli]
        # The reshape keeps the 2-D shape for n == 0, where np.array([]) alone is shape (0,).
        return np.array(rows, dtype=float).reshape(n, len(self.cells))
|