from __future__ import annotations

import functools
import typing
from dataclasses import dataclass, field
from functools import partial

import matplotlib as mpl
import numpy as np
import sympy as sp

from utils import get_orientation_phase_grid
sp.init_printing()

# Symbols shared by the grating, receptive-field and response expressions below.
# sigma_x / sigma_y are created here but rebound to the isotropic ``sigma`` further down.
k, x0, y0, phi_rf, theta_rf, sigma_x, sigma_y, sigma, x, y, theta_grating, phi_grating = sp.symbols(
    r'k x_0 y_0 \phi_{rf} \theta_{rf} \sigma_x \sigma_y \sigma x y \theta_{grating} \phi_{grating}'
)

# Default numeric values for the symbols above; used as dataclass field defaults.
# ``sigma`` has a single entry: its effective default is 1 (Cell.random also pins
# sigma_val to 1). A second ``sigma`` key in a dict literal would silently override
# the first one.
defaults = {
    k: 6,
    sigma: 1,
    phi_rf: sp.pi / 2,
    phi_grating: sp.pi / 2,
    theta_grating: 0,
    x0: 0,
    y0: 0,
}
# The receptive field is isotropic, so both axis widths are the single symbol ``sigma``.
sigma_x = sigma_y = sigma

# Sinusoidal grating: wave number k, orientation theta_grating and phase phi_grating,
# translated by (x0, y0).
grating_f = sp.cos(
    k * (x - x0) * sp.cos(theta_grating) + k * (y - y0) * sp.sin(theta_grating) + phi_grating
)

# Gabor receptive field centred at the origin. It is the normalised 2-D Gaussian density
# with standard deviation sigma, multiplied by a cosine carrier of the same form as the
# grating. Note that it does not contain x0 / y0.
receptive_field = 1 / (2 * sp.pi * sigma * sigma) * sp.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) * sp.cos(
    k * x * sp.cos(theta_rf) + k * y * sp.sin(theta_rf) + phi_rf
)

# Closed-form response  p = integral over the plane of receptive_field * grating_f.
# Expand cos(A) * cos(B) = 1/2 [cos(A + B) + cos(A - B)]. Then, for X ~ N(0, sigma^2 I),
# use
#     E[cos(a . X + c)] = exp(-sigma^2 |a|^2 / 2) * cos(c),
# where |a|^2 = 2 k^2 (1 +/- cos(dtheta)). This gives
#     p = 1/2 [ exp(-k^2 sigma^2 (1 + cos dtheta)) cos(phi_rf + phi_grating - k d)
#             + exp(-k^2 sigma^2 (1 - cos dtheta)) cos(phi_rf - phi_grating + k d) ],
# with dtheta = theta_grating - theta_rf and d = x0 cos(theta_grating) + y0 sin(theta_grating).
# A grating matching the cell's orientation (dtheta = 0) therefore responds with about
# 1/2 cos(phi_rf - phi_grating).
# The common exp(-k^2 sigma^2) factor is folded into each exponent. This avoids the
# exp(-big) * exp(+big) = 0 * inf -> nan that occurs for large k * sigma.
_k2_sigma2 = k * k * sigma * sigma
_cos_dtheta = sp.cos(theta_grating - theta_rf)
_grating_shift = k * (x0 * sp.cos(theta_grating) + y0 * sp.sin(theta_grating))
p = sp.Rational(1, 2) * (
    sp.exp(-_k2_sigma2 * (1 + _cos_dtheta)) * sp.cos(phi_rf + phi_grating - _grating_shift)
    + sp.exp(-_k2_sigma2 * (1 - _cos_dtheta)) * sp.cos(phi_rf - phi_grating + _grating_shift)
)
# Discrete value grids. Cell.random draws phi_val / theta_val from phi_split /
# theta_split (radians); the per-grid weight arrays it accepts have these lengths.
sigma_split, k_split, xy_split = (
    np.arange(0.1, 1, 0.05),   # candidate sigma values
    np.arange(0.2, 6, 0.2),    # candidate wave numbers k
    np.arange(-1, 1, 0.05),    # candidate centre coordinates
)
phi_split = np.arange(0, 2 * np.pi, np.pi / 100)   # phases in [0, 2*pi)
theta_split = np.arange(0, np.pi, np.pi / 100)     # orientations in [0, pi)
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)); works element-wise on numpy arrays."""
    denominator = np.exp(-x) + 1
    return 1 / denominator
@dataclass
class Cell:
    """A model cell whose parameters are plugged into the module-level expressions.

    Field -> symbol mapping: ``phi_val`` -> ``phi_rf``, ``theta_val`` -> ``theta_rf``,
    ``sigma_val`` -> ``sigma``, ``x0_val`` / ``y0_val`` -> ``x0`` / ``y0``, ``k_val`` -> ``k``.
    """

    phi_val: float  # receptive-field phase (radians; random cells draw it from phi_split)
    theta_val: float  # receptive-field orientation (radians; drawn from theta_split)
    sigma_val: float = defaults[sigma]
    x0_val: float = defaults[x0]
    y0_val: float = defaults[y0]
    k_val: float = defaults[k]

    @classmethod
    def random(
        cls,
        phi_dist: typing.Optional[np.ndarray] = None,
        theta_dist: typing.Optional[np.ndarray] = None,
        sigma_dist: typing.Optional[np.ndarray] = None,
        k_val: float = defaults[k],
        xy_dist: typing.Optional[np.ndarray] = None,
    ) -> Cell:
        """Draw a cell with random phase and orientation.

        :param phi_dist: unnormalised weights over ``phi_split`` (uniform when omitted).
        :param theta_dist: unnormalised weights over ``theta_split`` (uniform when omitted).
        :param sigma_dist: currently ignored; ``sigma_val`` is always 1.
        :param k_val: wave number assigned to the cell.
        :param xy_dist: currently ignored; the cell is always placed at x0 = y0 = 0.
        :return: a new Cell.
        """
        if phi_dist is None:
            phi_dist = np.ones(len(phi_split))
        if theta_dist is None:
            theta_dist = np.ones(len(theta_split))
        # Keyword arguments evaluate left to right: phase is drawn before orientation.
        return cls(
            sigma_val=1,
            phi_val=np.random.choice(phi_split, p=phi_dist / np.sum(phi_dist)),
            theta_val=np.random.choice(theta_split, p=theta_dist / np.sum(theta_dist)),
            x0_val=0,
            y0_val=0,
            k_val=k_val,
        )

    @property
    def sympy_func(self) -> sp.Expr:
        """Receptive field of this cell as a sympy expression in ``x`` and ``y``.

        NOTE(review): ``receptive_field`` contains no x0 / y0, so x0_val / y0_val do not
        shift the returned expression.
        """
        expr = receptive_field
        for symbol, value in (
            (sigma, self.sigma_val),
            (x0, self.x0_val),
            (y0, self.y0_val),
            (k, self.k_val),
            (phi_rf, self.phi_val),
            (theta_rf, self.theta_val),
        ):
            expr = expr.subs(symbol, value)
        return expr

    @staticmethod
    @functools.lru_cache(maxsize=4096, typed=True)
    def _cached_tuning_function(sigma_val, x0_val, y0_val, k_val, phi_val, theta_val):
        """Lambdify ``p`` for one parameter set. Cached because sp.lambdify is expensive."""
        expr = p
        for symbol, value in (
            (sigma, sigma_val),
            (x0, x0_val),
            (y0, y0_val),
            (k, k_val),
            (phi_rf, phi_val),
            (theta_rf, theta_val),
        ):
            expr = expr.subs(symbol, value)
        return sp.lambdify((theta_grating, phi_grating), expr, 'numpy')

    def get_tuning_function(self) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]:
        """
        Get the tuning sympy function as a numpy lambda function of theta and phi.

        The lambdified function is reused for identical parameter values, so repeated
        get_value calls do not re-run sp.lambdify.
        :return: a function (theta, phi) -> value, with grating angles in radians
        """
        params = (self.sigma_val, self.x0_val, self.y0_val, self.k_val, self.phi_val, self.theta_val)
        try:
            return self._cached_tuning_function(*params)
        except TypeError:
            # Unhashable parameter values (e.g. arrays) cannot be cache keys; build uncached.
            return self._cached_tuning_function.__wrapped__(*params)

    def get_value(self, theta_deg: float, phi_deg: float) -> float:
        """Tuning value for a grating given in degrees (converted to radians here)."""
        return float(self.get_tuning_function()(theta_deg * np.pi / 180, phi_deg * np.pi / 180))

    def get_tuning_plot(self, theta_step_deg: float, phi_step_deg: float) -> np.ndarray:
        """Evaluate the tuning function on a grid; ``grid[..., 0]`` is theta, ``grid[..., 1]`` phi.

        NOTE(review): the grid values are passed to the tuning function unchanged, and that
        function takes radians (get_value converts degrees first). Confirm that
        get_orientation_phase_grid returns radians even though its step sizes are in degrees.
        """
        grid = get_orientation_phase_grid(theta_step_deg, phi_step_deg)
        return self.get_tuning_function()(grid[:, :, 0], grid[:, :, 1])
@dataclass
class Grating:
    """Grating stimulus whose fields are substituted into ``grating_f``."""

    phi_val: float = defaults[phi_grating]  # phase, replaces phi_grating
    theta_val: float = defaults[theta_grating]  # orientation, replaces theta_grating
    k_val: float = defaults[k]  # wave number, replaces k

    @property
    def sympy_func(self) -> sp.Expr:
        """Return ``grating_f`` with this grating's k, phase and orientation plugged in."""
        expression = grating_f
        replacements = (
            (k, self.k_val),
            (phi_grating, self.phi_val),
            (theta_grating, self.theta_val),
        )
        for symbol, value in replacements:
            expression = expression.subs(symbol, value)
        return expression
@dataclass
class Population:
    """A collection of cells whose responses are evaluated together."""

    cells: typing.List[Cell] = field(default_factory=list)

    def _param_column(self, attr: str) -> np.ndarray:
        """Stack ``getattr(cell, attr)`` over all cells into a float column of shape (n_cells, 1)."""
        return np.array([getattr(cell, attr) for cell in self.cells], dtype=float).reshape((-1, 1))

    @property
    def response_func(self) -> typing.Callable[[float, float], np.ndarray]:
        """
        Use sp.lambdify and the expression to generate the necessary function.

        All cell parameters, including each cell's x0_val / y0_val, are bound. This matches
        Cell.get_tuning_function. The remaining arguments are (phi_grating, theta_grating)
        in radians. For scalar arguments the result has shape (n_cells, 1).
        :return: a function (phi, theta) -> responses
        """
        evaluate = sp.lambdify(
            (x0, y0, k, sigma, phi_rf, theta_rf, phi_grating, theta_grating),
            p,
            'numpy',
        )
        return partial(
            evaluate,
            self._param_column('x0_val'),
            self._param_column('y0_val'),
            self._param_column('k_val'),
            self._param_column('sigma_val'),
            self._param_column('phi_val'),
            self._param_column('theta_val'),
        )

    @classmethod
    def random(
        cls,
        n: int,
        phi_dist: typing.Optional[np.ndarray] = None,
        theta_dist: typing.Optional[np.ndarray] = None,
        sigma_dist: typing.Optional[np.ndarray] = None,
        k_val: float = defaults[k],
        xy_dist: typing.Optional[np.ndarray] = None,
    ) -> Population:
        """Build a population of ``n`` independently drawn Cell.random cells.

        Omitted weight arrays default to uniform weights over the matching ``*_split`` grid.
        """
        phi_dist = np.ones(len(phi_split)) if phi_dist is None else phi_dist
        theta_dist = np.ones(len(theta_split)) if theta_dist is None else theta_dist
        sigma_dist = np.ones(len(sigma_split)) if sigma_dist is None else sigma_dist
        xy_dist = np.ones(len(xy_split)) if xy_dist is None else xy_dist
        return cls(cells=[Cell.random(phi_dist, theta_dist, sigma_dist, k_val, xy_dist) for _ in range(n)])

    @staticmethod
    def _respond(evaluate, phi_deg, theta_deg, coef, use_sigmoid) -> np.ndarray:
        """Evaluate a response_func at degree angles and apply gain plus optional sigmoid."""
        raw = evaluate(phi_deg * np.pi / 180, theta_deg * np.pi / 180)
        scaled = np.asarray(raw, dtype=float).reshape(-1) * coef
        return sigmoid(scaled) if use_sigmoid else scaled

    def get_response(self, phi_deg: float, theta_deg: float, coef: float = 4, use_sigmoid: bool = True) -> np.ndarray:
        """Response of every cell to one grating.

        :param phi_deg: grating phase in degrees.
        :param theta_deg: grating orientation in degrees.
        :param coef: gain applied to the raw response.
        :param use_sigmoid: squash the scaled response through ``sigmoid``.
        :return: array of shape (n_cells,).
        """
        # One vectorised evaluation for all cells instead of lambdifying p once per cell.
        return self._respond(self.response_func, phi_deg, theta_deg, coef, use_sigmoid)

    def sample_responses(
        self,
        n: int,
        noise_sigma: float = 0,
        coef: float = 2,
        use_sigmoid: bool = True,
        custom_grid: typing.Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Sample population responses to random (or caller-supplied) gratings.

        :param n: number of random gratings; ignored when ``custom_grid`` is given.
        :param noise_sigma: std of Gaussian noise added to the response channel only.
        :param coef: gain passed to the response computation.
        :param use_sigmoid: whether the scaled response is squashed by ``sigmoid``.
        :param custom_grid: optional sequence of (phi_deg, theta_deg) rows. It replaces the
            ``n`` uniform draws from [0, 360) x [0, 360).
        :return: array of shape (n_samples, n_cells, 3). Each cell holds
            (response, phi_deg, theta_deg). The response is computed at ``theta_deg % 180``,
            while the stored label keeps the raw ``theta_deg``.
        """
        grid = np.random.uniform(0, 360, (n, 2)) if custom_grid is None else custom_grid
        n_samples, n_cells = len(grid), len(self.cells)
        evaluate = self.response_func  # lambdify once for the whole batch
        samples = np.empty((n_samples, n_cells, 3))
        for i, (phi_deg, theta_deg) in enumerate(grid):
            samples[i, :, 0] = self._respond(evaluate, phi_deg, theta_deg % 180, coef, use_sigmoid)
            samples[i, :, 1] = phi_deg
            samples[i, :, 2] = theta_deg
        # Noise is always drawn (scale 0 on the label channels), so RNG consumption does
        # not depend on noise_sigma. Preallocating keeps an empty grid well-shaped.
        noise = np.random.normal(0, [noise_sigma, 0, 0], (n_samples, n_cells, 3))
        return samples + noise