|
|
@ -1,6 +1,7 @@ |
|
|
|
from __future__ import annotations |
|
|
|
from __future__ import annotations |
|
|
|
import typing |
|
|
|
import typing |
|
|
|
from dataclasses import dataclass, field |
|
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
|
|
from functools import partial |
|
|
|
|
|
|
|
|
|
|
|
import matplotlib as mpl |
|
|
|
import matplotlib as mpl |
|
|
|
import sympy as sp |
|
|
|
import sympy as sp |
|
|
@ -9,29 +10,41 @@ import numpy as np |
|
|
|
from utils import get_orientation_phase_grid |
|
|
|
from utils import get_orientation_phase_grid |
|
|
|
|
|
|
|
|
|
|
|
sp.init_printing() |
|
|
|
sp.init_printing() |
|
|
|
k, x0, y0, phi, theta, sigma_x, sigma_y, sigma, x, y = sp.symbols(r'k x_0 y_0 \phi \theta \sigma_x \sigma_y \sigma x y') |
|
|
|
k, x0, y0, phi_rf, theta_rf, sigma_x, sigma_y, sigma, x, y, theta_grating, phi_grating = sp.symbols(r'k x_0 y_0 \phi_{rf} \theta_{rf} \sigma_x \sigma_y \sigma x y \theta_{grating} \phi_{grating}') |
|
|
|
defaults = { |
|
|
|
defaults = { |
|
|
|
k: 6, |
|
|
|
k: 6, |
|
|
|
sigma: 0.2, |
|
|
|
sigma: 0.2, |
|
|
|
phi: sp.pi / 2, |
|
|
|
phi_rf: sp.pi / 2, |
|
|
|
theta: 0, |
|
|
|
phi_grating: sp.pi / 2, |
|
|
|
|
|
|
|
theta_grating: 0, |
|
|
|
sigma: 1, |
|
|
|
sigma: 1, |
|
|
|
x0: 0, y0: 0 |
|
|
|
x0: 0, y0: 0 |
|
|
|
} |
|
|
|
} |
|
|
|
sigma_x = sigma_y = sigma |
|
|
|
sigma_x = sigma_y = sigma |
|
|
|
grating_f = sp.cos(k * (x - x0) * sp.cos(theta) + k * (y - y0) * sp.sin(theta) + phi) |
|
|
|
grating_f = sp.cos(k * (x - x0) * sp.cos(theta_grating) + k * (y - y0) * sp.sin(theta_grating) + phi_grating) |
|
|
|
receptive_field = 1 / (2 * sp.pi * sigma * sigma) * sp.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) * sp.cos( |
|
|
|
receptive_field = 1 / (2 * sp.pi * sigma * sigma) * sp.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) * sp.cos( |
|
|
|
k * x * sp.cos(theta) + k * y * sp.sin(theta) + phi) |
|
|
|
k * x * sp.cos(theta_rf) + k * y * sp.sin(theta_rf) + phi_rf) |
|
|
|
receptive_field = receptive_field.subs(theta, 0).subs(phi, 0) |
|
|
|
|
|
|
|
# p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta)) * sp.exp(k ** 2 * (1 + sp.cos(theta) ** 2) / 2) * sp.cos( |
|
|
|
# p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta)) * sp.exp(k ** 2 * (1 + sp.cos(theta) ** 2) / 2) * sp.cos( |
|
|
|
# phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta))) |
|
|
|
# phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta))) |
|
|
|
|
|
|
|
|
|
|
|
p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta) * 4) * sp.exp(-4 * k ** 2 * sigma ** 2) * sp.cos( |
|
|
|
# p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta) * 4) * sp.exp(-4 * k ** 2 * sigma ** 2) * sp.cos( |
|
|
|
phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta))) |
|
|
|
# phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
p = (1 / 2) * sp.exp(-k*k*sigma*sigma) * ( |
|
|
|
|
|
|
|
sp.exp(-k*k*sigma*sigma*sp.sin(theta_grating + theta_rf)) * sp.cos(phi_grating + phi_rf + 2 * k / (sigma * sigma) * ( |
|
|
|
|
|
|
|
x0 * sp.cos(theta_rf) + y0 * sp.sin(theta_grating) |
|
|
|
|
|
|
|
+ x0 * sp.sin(theta_rf) + y0 * sp.cos(theta_grating) |
|
|
|
|
|
|
|
)) + |
|
|
|
|
|
|
|
sp.exp( k*k*sigma*sigma*sp.sin(theta_grating + theta_rf)) * sp.cos(phi_grating - phi_rf + 2 * k / (sigma * sigma) * ( |
|
|
|
|
|
|
|
-x0 * sp.cos(theta_rf) - y0 * sp.sin(theta_rf) |
|
|
|
|
|
|
|
+ x0 * sp.sin(theta_rf) + y0 * sp.cos(theta_grating) |
|
|
|
|
|
|
|
))) |
|
|
|
|
|
|
|
|
|
|
|
sigma_split = np.arange(0.1, 1, 0.05) |
|
|
|
sigma_split = np.arange(0.1, 1, 0.05) |
|
|
|
k_split = np.arange(0.2, 6, 0.2) |
|
|
|
k_split = np.arange(0.2, 6, 0.2) |
|
|
|
xy_split = np.arange(-1, 1, 0.05) |
|
|
|
xy_split = np.arange(-1, 1, 0.05) |
|
|
|
|
|
|
|
phi_split = np.arange(0, 2 * np.pi, np.pi / 100) |
|
|
|
|
|
|
|
theta_split = np.arange(0, np.pi, np.pi / 100) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sigmoid(x): |
|
|
|
def sigmoid(x): |
|
|
@ -40,26 +53,34 @@ def sigmoid(x): |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
|
@dataclass |
|
|
|
class Cell: |
|
|
|
class Cell: |
|
|
|
|
|
|
|
phi_val: float |
|
|
|
|
|
|
|
theta_val: float |
|
|
|
sigma_val: float = defaults[sigma] |
|
|
|
sigma_val: float = defaults[sigma] |
|
|
|
x0_val: float = defaults[x0] |
|
|
|
x0_val: float = defaults[x0] |
|
|
|
y0_val: float = defaults[y0] |
|
|
|
y0_val: float = defaults[y0] |
|
|
|
k_val: float = defaults[k] |
|
|
|
k_val: float = defaults[k] |
|
|
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
|
@classmethod |
|
|
|
def random(cls, sigma_dist: np.ndarray = np.ones(len(sigma_split)), |
|
|
|
def random(cls, |
|
|
|
|
|
|
|
phi_dist: np.ndarray = np.ones(len(phi_split)), |
|
|
|
|
|
|
|
theta_dist: np.ndarray = np.ones(len(theta_split)), |
|
|
|
|
|
|
|
sigma_dist: np.ndarray = np.ones(len(sigma_split)), |
|
|
|
k_val: float = defaults[k], |
|
|
|
k_val: float = defaults[k], |
|
|
|
xy_dist: np.ndarray = np.ones(len(xy_split))): |
|
|
|
xy_dist: np.ndarray = np.ones(len(xy_split))): |
|
|
|
return cls( |
|
|
|
return cls( |
|
|
|
sigma_val=np.random.choice(sigma_split, p=sigma_dist / np.sum(sigma_dist)), |
|
|
|
sigma_val=1, # np.random.choice(sigma_split, p=sigma_dist / np.sum(sigma_dist)), |
|
|
|
x0_val=np.random.choice(xy_split, p=xy_dist / np.sum(xy_dist)), |
|
|
|
phi_val=np.random.choice(phi_split, p=phi_dist / np.sum(phi_dist)), |
|
|
|
y0_val=np.random.choice(xy_split, p=xy_dist / np.sum(xy_dist)), |
|
|
|
theta_val=np.random.choice(theta_split, p=theta_dist / np.sum(theta_dist)), |
|
|
|
k_val=k_val |
|
|
|
x0_val=0, # np.random.choice(xy_split, p=xy_dist / np.sum(xy_dist)), |
|
|
|
|
|
|
|
y0_val=0, # np.random.choice(xy_split, p=xy_dist / np.sum(xy_dist)), |
|
|
|
|
|
|
|
k_val=k_val, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
@property |
|
|
|
def sympy_func(self) -> sp.Expr: |
|
|
|
def sympy_func(self) -> sp.Expr: |
|
|
|
return receptive_field.subs(sigma, self.sigma_val).subs(x0, self.x0_val).subs(y0, self.y0_val).subs(k, |
|
|
|
return receptive_field\ |
|
|
|
self.k_val) |
|
|
|
.subs(sigma, self.sigma_val).subs(x0, self.x0_val).subs(y0, self.y0_val)\ |
|
|
|
|
|
|
|
.subs(k, self.k_val).subs(phi_rf, self.phi_val).subs(theta_rf, self.theta_val) |
|
|
|
|
|
|
|
|
|
|
|
def get_tuning_function(self) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]: |
|
|
|
def get_tuning_function(self) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]: |
|
|
|
""" |
|
|
|
""" |
|
|
@ -67,8 +88,10 @@ class Cell: |
|
|
|
:return: a function (theta, phi) -> value |
|
|
|
:return: a function (theta, phi) -> value |
|
|
|
""" |
|
|
|
""" |
|
|
|
return sp.lambdify( |
|
|
|
return sp.lambdify( |
|
|
|
(theta, phi), |
|
|
|
(theta_grating, phi_grating), |
|
|
|
p.subs(sigma, self.sigma_val).subs(x0, self.x0_val).subs(y0, self.y0_val).subs(k, self.k_val), |
|
|
|
p.subs(sigma, self.sigma_val).subs(x0, self.x0_val)\ |
|
|
|
|
|
|
|
.subs(y0, self.y0_val).subs(k, self.k_val)\ |
|
|
|
|
|
|
|
.subs(phi_rf, self.phi_val).subs(theta_rf, self.theta_val), |
|
|
|
'numpy') |
|
|
|
'numpy') |
|
|
|
|
|
|
|
|
|
|
|
def get_value(self, theta_deg: float, phi_deg: float) -> float: |
|
|
|
def get_value(self, theta_deg: float, phi_deg: float) -> float: |
|
|
@ -81,24 +104,43 @@ class Cell: |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
|
@dataclass |
|
|
|
class Grating: |
|
|
|
class Grating: |
|
|
|
|
|
|
|
phi_val: float = defaults[phi_grating] |
|
|
|
|
|
|
|
theta_val: float = defaults[theta_grating] |
|
|
|
k_val: float = defaults[k] |
|
|
|
k_val: float = defaults[k] |
|
|
|
phi_val: float = defaults[phi] |
|
|
|
|
|
|
|
theta_val: float = defaults[theta] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
@property |
|
|
|
def sympy_func(self) -> sp.Expr: |
|
|
|
def sympy_func(self) -> sp.Expr: |
|
|
|
return grating_f.subs(k, self.k_val).subs(phi, self.phi_val).subs(theta, self.theta_val) |
|
|
|
return grating_f.subs(k, self.k_val).subs(phi_grating, self.phi_val).subs(theta_grating, self.theta_val) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
|
@dataclass |
|
|
|
class Population: |
|
|
|
class Population: |
|
|
|
cells: typing.List[Cell] = field(default_factory=list) |
|
|
|
cells: typing.List[Cell] = field(default_factory=list) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
|
|
def response_func(self) -> typing.Callable[[float, float], np.ndarray]: |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
Use sp.lambdify and the expression to generate the necessary function. |
|
|
|
|
|
|
|
:return: a function (phi, theta) -> responses |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
return partial( |
|
|
|
|
|
|
|
sp.lambdify( |
|
|
|
|
|
|
|
(x0, y0, k, sigma, phi_rf, theta_rf, phi_grating, theta_grating, ), |
|
|
|
|
|
|
|
p, 'numpy'), |
|
|
|
|
|
|
|
0, 0, np.array([cell.k_val for cell in self.cells]).reshape((-1, 1)), |
|
|
|
|
|
|
|
np.array([cell.sigma_val for cell in self.cells]).reshape((-1, 1)), |
|
|
|
|
|
|
|
np.array([cell.phi_val for cell in self.cells]).reshape((-1, 1)), |
|
|
|
|
|
|
|
np.array([cell.theta_val for cell in self.cells]).reshape((-1, 1)), |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
|
@classmethod |
|
|
|
def random(cls, n: int, sigma_dist: np.ndarray = np.ones(len(sigma_split)), |
|
|
|
def random(cls, n: int, |
|
|
|
|
|
|
|
phi_dist: np.ndarray = np.ones(len(phi_split)), |
|
|
|
|
|
|
|
theta_dist: np.ndarray = np.ones(len(theta_split)), |
|
|
|
|
|
|
|
sigma_dist: np.ndarray = np.ones(len(sigma_split)), |
|
|
|
k_val: float = defaults[k], |
|
|
|
k_val: float = defaults[k], |
|
|
|
xy_dist: np.ndarray = np.ones(len(xy_split))): |
|
|
|
xy_dist: np.ndarray = np.ones(len(xy_split))): |
|
|
|
return cls(cells=[Cell.random(sigma_dist, k_val, xy_dist) for _ in range(n)]) |
|
|
|
return cls(cells=[Cell.random(phi_dist, theta_dist, sigma_dist, k_val, xy_dist) for _ in range(n)]) |
|
|
|
|
|
|
|
|
|
|
|
def get_response(self, phi_deg: float, theta_deg: float, coef: float = 4, use_sigmoid: bool = True) -> np.ndarray: |
|
|
|
def get_response(self, phi_deg: float, theta_deg: float, coef: float = 4, use_sigmoid: bool = True) -> np.ndarray: |
|
|
|
return (sigmoid if use_sigmoid else (lambda x: x))(np.array([cell.get_value(theta_deg, phi_deg) for cell in self.cells]) * coef) |
|
|
|
return (sigmoid if use_sigmoid else (lambda x: x))(np.array([cell.get_value(theta_deg, phi_deg) for cell in self.cells]) * coef) |
|
|
|