from __future__ import annotations import typing from dataclasses import dataclass, field from itertools import chain from functools import partial import sympy as sp import numpy as np from utils import get_orientation_phase_grid sp.init_printing() k, x0, y0, phi_rf, theta_rf, sigma_x, sigma_y, sigma, x, y, theta_grating, phi_grating = sp.symbols(r'k x_0 y_0 \phi_{rf} \theta_{rf} \sigma_x \sigma_y \sigma x y \theta_{grating} \phi_{grating}') defaults = { k: 6, sigma: 1, phi_rf: sp.pi / 2, phi_grating: sp.pi / 2, theta_grating: 0, sigma: 1, x0: 0, y0: 0 } sigma_x = sigma_y = sigma grating_f = sp.cos(k * (x - x0) * sp.cos(theta_grating) + k * (y - y0) * sp.sin(theta_grating) + phi_grating) receptive_field = 1 / (2 * sp.pi * sigma * sigma) * sp.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) * sp.cos( k * x * sp.cos(theta_rf) + k * y * sp.sin(theta_rf) + phi_rf) # p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta)) * sp.exp(k ** 2 * (1 + sp.cos(theta) ** 2) / 2) * sp.cos( # phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta))) # p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta) * 4) * sp.exp(-4 * k ** 2 * sigma ** 2) * sp.cos( # phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta))) p = (1 / 2) * sp.exp(-k*k*sigma*sigma) * ( sp.exp(-k*k*sigma*sigma*sp.sin(theta_grating + theta_rf)) * sp.cos(phi_grating + phi_rf + 2 * k / (sigma * sigma) * ( x0 * sp.cos(theta_rf) + y0 * sp.sin(theta_grating) + x0 * sp.sin(theta_rf) + y0 * sp.cos(theta_grating) )) + sp.exp( k*k*sigma*sigma*sp.sin(theta_grating + theta_rf)) * sp.cos(phi_grating - phi_rf + 2 * k / (sigma * sigma) * ( -x0 * sp.cos(theta_rf) - y0 * sp.sin(theta_rf) + x0 * sp.sin(theta_rf) + y0 * sp.cos(theta_grating) ))) sigma_split = np.arange(0.1, 1, 0.05) k_split = np.arange(0.2, 6, 0.2) xy_split = np.arange(-1, 1, 0.05) phi_split = np.arange(0, 2 * np.pi, np.pi / 100) theta_split = np.arange(0, np.pi, np.pi / 100) # The second option is (the distribution function, the function that takes the size and returns the step and the 
# A Distribution describes how to generate parameter values. It is one of:
#   * a number -> every sample equals that constant;
#   * a dict {value: probability} -> i.i.d. draws via np.random.choice;
#   * a tuple (density, make_grid) -> a deterministic grid. make_grid(size)
#     returns (step, start), and each next point is the previous one plus
#     step / density(previous), so higher density packs points closer together.
Distribution = typing.Union[
    float,
    typing.Dict[float, float],
    typing.Tuple[typing.Callable[[float], float], typing.Callable[[int], typing.Tuple[float, float]]],
]

# Numeric (numpy) version of the closed-form tuning expression `p`, compiled
# once at import instead of on every call. The argument order is the one
# Population.response_func exposes.
_P_NUMERIC = sp.lambdify(
    (x0, y0, k, sigma, phi_rf, theta_rf, phi_grating, theta_grating), p, 'numpy')


def sample_distribution(distribution: Distribution, size: int = 1) -> np.ndarray:
    """Draw ``size`` values from ``distribution``.

    :param distribution: a constant, a ``{value: probability}`` dict, or a
        ``(density, make_grid)`` tuple (see ``Distribution``).
    :param size: number of values to produce.
    :return: a 1-D array of exactly ``size`` values.
    :raises ValueError: if ``distribution`` is none of the supported forms.
    """
    if isinstance(distribution, (int, float, np.integer, np.floating)):
        return np.full(size, float(distribution))
    if isinstance(distribution, dict):
        return np.random.choice(list(distribution.keys()), size, p=list(distribution.values()))
    if isinstance(distribution, tuple):
        density, make_grid = distribution[0], distribution[1]
        step, start = make_grid(size)
        samples = []
        current = start
        # Emit exactly `size` points: start, start + step / density(start), ...
        # The length must match the other forms, because Population.sample
        # zips these arrays to build its phase x orientation grid.
        for _ in range(size):
            samples.append(current)
            current = current + step / density(current)
        return np.array(samples)
    raise ValueError(f'Unknown distribution type: {type(distribution)}')


def get_uniform_dist(start: float, stop: float) -> Distribution:
    """Return an evenly spaced grid Distribution from ``start`` to ``stop`` inclusive.

    Like ``np.linspace``, both endpoints are included. For a periodic range
    such as [0, 2*pi], the first and last samples therefore coincide modulo
    the period.
    """
    return (lambda _x: 1, lambda size: ((stop - start) / (size - 1 or 1), start))


phi_dist_uni = get_uniform_dist(0, 2 * np.pi)
theta_dist_uni = get_uniform_dist(0, np.pi)


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
    return 1 / (1 + np.exp(-x))


@dataclass
class Cell:
    """A model cell with a Gabor receptive field.

    ``phi_val`` and ``theta_val`` are the receptive-field phase and
    orientation. Population.sample draws them in radians, from [0, 2*pi] and
    [0, pi] by default. ``sigma_val`` is the Gaussian width and ``k_val`` the
    spatial frequency. ``x0_val``/``y0_val`` are the values substituted for
    the x0/y0 symbols.
    """
    phi_val: float
    theta_val: float
    sigma_val: float = defaults[sigma]
    x0_val: float = defaults[x0]
    y0_val: float = defaults[y0]
    k_val: float = defaults[k]

    @property
    def sympy_func(self) -> sp.Expr:
        """The receptive-field expression with this cell's parameters substituted (a function of x, y)."""
        return receptive_field.subs({
            sigma: self.sigma_val, x0: self.x0_val, y0: self.y0_val,
            k: self.k_val, phi_rf: self.phi_val, theta_rf: self.theta_val,
        })

    def get_tuning_function(self) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]:
        """
        Get this cell's tuning curve as a numpy function of grating orientation and phase.

        The cell's parameters are captured when this method is called.

        :return: a function (theta, phi) -> value, with grating angles in radians
        """
        params = (float(self.x0_val), float(self.y0_val), float(self.k_val),
                  float(self.sigma_val), float(self.phi_val), float(self.theta_val))

        def tuning(theta, phi):
            # _P_NUMERIC takes the grating phase before the grating orientation.
            return _P_NUMERIC(*params, phi, theta)

        return tuning

    def get_value(self, theta_deg: float, phi_deg: float) -> float:
        """Evaluate the tuning curve at a grating orientation and phase given in degrees."""
        return float(self.get_tuning_function()(theta_deg * np.pi / 180, phi_deg * np.pi / 180))

    def get_tuning_plot(self, theta_step_deg: float, phi_step_deg: float) -> np.ndarray:
        """Evaluate the tuning curve over the grid from ``get_orientation_phase_grid``.

        Assumes ``grid[..., 0]`` holds orientations and ``grid[..., 1]`` phases,
        both in radians. TODO: confirm against utils.
        """
        grid = get_orientation_phase_grid(theta_step_deg, phi_step_deg)
        return self.get_tuning_function()(grid[:, :, 0], grid[:, :, 1])


@dataclass
class Grating:
    """A sinusoidal grating stimulus with phase, orientation and spatial frequency."""
    phi_val: float = defaults[phi_grating]
    theta_val: float = defaults[theta_grating]
    k_val: float = defaults[k]

    @property
    def sympy_func(self) -> sp.Expr:
        """The grating expression with this stimulus' parameters substituted."""
        return grating_f.subs({k: self.k_val, phi_grating: self.phi_val, theta_grating: self.theta_val})


@dataclass
class Population:
    """A collection of model cells and helpers for simulating their responses."""
    cells: typing.List[Cell] = field(default_factory=list)

    @property
    def response_func(self) -> typing.Callable[[float, float], np.ndarray]:
        """
        Vectorized tuning function of the whole population.

        The per-cell parameters are passed as column vectors, so inputs
        broadcast against one row per cell. x0 and y0 are fixed to 0.

        :return: a function (phi, theta) -> responses
        """
        def column(values: typing.List[float]) -> np.ndarray:
            return np.array(values).reshape((-1, 1))

        return partial(
            _P_NUMERIC,
            0,
            0,
            column([cell.k_val for cell in self.cells]),
            column([cell.sigma_val for cell in self.cells]),
            column([cell.phi_val for cell in self.cells]),
            column([cell.theta_val for cell in self.cells]),
        )

    @classmethod
    def sample(cls, n_orient: int, n_phases: int,
               phi_dist: Distribution = phi_dist_uni,
               theta_dist: Distribution = theta_dist_uni,
               sigma_dist: Distribution = defaults[sigma],
               k_val: float = defaults[k],
               xy_dist: Distribution = get_uniform_dist(-5, 5)) -> Population:
        """Build an ``n_orient`` x ``n_phases`` grid of cells.

        Every one of the ``n_orient`` orientations is paired with every one of
        the ``n_phases`` phases. Cells are ordered orientation-major.
        """
        n_cells = n_orient * n_phases
        # Draw order is phases, orientations, widths, x, y (as before).
        phis = np.tile(sample_distribution(phi_dist, n_phases), n_orient)
        thetas = np.repeat(sample_distribution(theta_dist, n_orient), n_phases)
        sigmas = sample_distribution(sigma_dist, n_cells)
        xs = sample_distribution(xy_dist, n_cells)
        ys = sample_distribution(xy_dist, n_cells)
        return cls(cells=[
            # Offsets are drawn but multiplied by 0, so every cell sits at the
            # origin. Presumably this disables positional offsets while keeping
            # the random stream unchanged; confirm before relying on it.
            Cell(phi_val=phi_val, theta_val=theta_val, sigma_val=sigma_val,
                 x0_val=x0_val * 0, y0_val=y0_val * 0, k_val=k_val)
            for phi_val, theta_val, sigma_val, x0_val, y0_val in zip(phis, thetas, sigmas, xs, ys)
        ])

    def get_response(self, phi_deg: float, theta_deg: float, coef: float = 4,
                     use_sigmoid: bool = True) -> np.ndarray:
        """Return every cell's response to a grating at (phi_deg, theta_deg), in degrees.

        Raw tuning values are multiplied by ``coef``. If ``use_sigmoid`` is
        set, they are then squashed through ``sigmoid``.
        """
        scaled = np.array([cell.get_value(theta_deg, phi_deg) for cell in self.cells]) * coef
        return sigmoid(scaled) if use_sigmoid else scaled

    def sample_responses(
            self,
            n: int,
            noise_sigma: float = 0,
            coef: float = 2,
            use_sigmoid: bool = True,
            custom_grid: typing.Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Simulate noisy population responses to gratings.

        Stimuli are ``n`` random (phi_deg, theta_deg) pairs drawn uniformly
        from [0, 360), or the rows of ``custom_grid`` if it is given. The
        orientation is reduced mod 180 before evaluating responses.

        :return: array of shape (n_stimuli, n_cells, 3). The last axis holds
            (response + Gaussian noise with std ``noise_sigma``, phi_deg, theta_deg).
        """
        grid = np.random.uniform(0, 360, (n, 2)) if custom_grid is None else custom_grid
        n_cells = len(self.cells)
        clean = np.array([
            np.stack([
                self.get_response(phi_deg, theta_deg % 180, coef=coef, use_sigmoid=use_sigmoid),
                np.ones(n_cells) * phi_deg,
                np.ones(n_cells) * theta_deg,
            ], axis=1)
            for phi_deg, theta_deg in grid
        ]).reshape(len(grid), n_cells, 3)  # keeps the 3-D shape even when the grid is empty
        # Noise is added to the response channel only; the label channels get std 0.
        noise = np.random.normal(0, [noise_sigma, 0, 0], (len(grid), n_cells, 3))
        return clean + noise