from __future__ import annotations
import typing
from dataclasses import dataclass , field
from itertools import chain
from functools import partial
import sympy as sp
import numpy as np
from utils import get_orientation_phase_grid
# Enable sympy's pretty printing for displayed expressions.
sp.init_printing()
# Symbols of the model:
#   k                          - wave number multiplying x, y inside the cosines below
#   x0, y0                     - spatial offsets (shift the grating; appear in p)
#   phi_rf, theta_rf           - phase and orientation angle of a cell's receptive field
#   sigma_x, sigma_y, sigma    - Gaussian widths (sigma_x / sigma_y are rebound to sigma below)
#   x, y                       - spatial coordinates
#   theta_grating, phi_grating - orientation angle and phase of the stimulus grating
# The raw string only supplies the LaTeX display names; sympy splits it on whitespace.
k, x0, y0, phi_rf, theta_rf, sigma_x, sigma_y, sigma, x, y, theta_grating, phi_grating = sp.symbols(r'k x_0 y_0 \phi_{rf} \theta_{rf} \sigma_x \sigma_y \sigma x y \theta_{grating} \phi_{grating}')
# Default numeric values for the symbolic parameters. Used as dataclass field
# defaults (Cell, Grating) and as default arguments of Population.sample.
# Each key appears once; the original literal repeated `sigma: 1`, which was a
# dead duplicate with the same value.
defaults = {
    k: 6,
    sigma: 1,
    phi_rf: sp.pi / 2,
    phi_grating: sp.pi / 2,
    theta_grating: 0,
    x0: 0,
    y0: 0,
}
# Isotropic model: rebind the sigma_x / sigma_y names to the single width symbol
# sigma, so any expression written with them uses sigma.
sigma_x = sigma_y = sigma
# Stimulus: a full-field cosine grating with wave number k, orientation angle
# theta_grating and phase phi_grating, shifted by (x0, y0).
grating_f = sp.cos(k * (x - x0) * sp.cos(theta_grating) + k * (y - y0) * sp.sin(theta_grating) + phi_grating)
# Receptive field: a unit-integral isotropic Gaussian envelope of width sigma,
# centred at the origin, times a cosine carrier with wave number k, orientation
# angle theta_rf and phase phi_rf (a Gabor filter).
# NOTE(review): x0 / y0 do not appear in this expression, so the receptive field
# is not shifted by them (only the grating is) - confirm this is intended.
receptive_field = 1 / (2 * sp.pi * sigma * sigma) * sp.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) * sp.cos(
    k * x * sp.cos(theta_rf) + k * y * sp.sin(theta_rf) + phi_rf)
# Earlier candidate tuning expressions (commented out, kept for reference):
# p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta)) * sp.exp(k ** 2 * (1 + sp.cos(theta) ** 2) / 2) * sp.cos(
#     phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta)))
# p = sp.cosh(k ** 2 * sigma ** 2 * sp.cos(theta) * 4) * sp.exp(-4 * k ** 2 * sigma ** 2) * sp.cos(
#     phi - k * (x0 * sp.cos(theta) + y0 * sp.sin(theta)))
#
# Tuning expression: the response of a cell (parameters k, sigma, phi_rf,
# theta_rf, x0, y0) to a grating (phi_grating, theta_grating). It is consumed by
# Cell.get_tuning_function and Population.response_func. Note that (1 / 2) is
# the Python float 0.5, so the expression carries a Float coefficient.
#
# NOTE(review): integrating receptive_field * grating_f over the plane directly
# appears to give
#   1/2 * exp(-k^2 sigma^2) * (
#       exp(-k^2 sigma^2 cos(theta_rf - theta_grating))
#           * cos(phi_rf + phi_grating - k * (x0 cos(theta_grating) + y0 sin(theta_grating)))
#     + exp(+k^2 sigma^2 cos(theta_rf - theta_grating))
#           * cos(phi_rf - phi_grating + k * (x0 cos(theta_grating) + y0 sin(theta_grating))))
# The expression below has the same overall structure but differs:
#   - it uses sin(theta_grating + theta_rf) in the exponents;
#   - it uses a 2k / sigma^2 factor on the offset terms;
#   - it mixes theta_rf and theta_grating inside those terms.
# Confirm which is intended. With x0 = y0 = 0 (the defaults) the offset terms
# vanish.
p = (1 / 2) * sp.exp(-k * k * sigma * sigma) * (
    sp.exp(-k * k * sigma * sigma * sp.sin(theta_grating + theta_rf)) * sp.cos(phi_grating + phi_rf + 2 * k / (sigma * sigma) * (
        x0 * sp.cos(theta_rf) + y0 * sp.sin(theta_grating)
        + x0 * sp.sin(theta_rf) + y0 * sp.cos(theta_grating)
    )) +
    sp.exp(k * k * sigma * sigma * sp.sin(theta_grating + theta_rf)) * sp.cos(phi_grating - phi_rf + 2 * k / (sigma * sigma) * (
        -x0 * sp.cos(theta_rf) - y0 * sp.sin(theta_rf)
        + x0 * sp.sin(theta_rf) + y0 * sp.cos(theta_grating)
    )))
# Evenly spaced parameter grids (np.arange: start inclusive, stop exclusive).
sigma_split, k_split, xy_split = (
    np.arange(0.1, 1.0, 0.05),   # candidate sigma values
    np.arange(0.2, 6.0, 0.2),    # candidate k values
    np.arange(-1.0, 1.0, 0.05),  # candidate x / y values
)
phi_split, theta_split = (
    np.arange(0.0, 2.0 * np.pi, np.pi / 100),  # phases in [0, 2*pi)
    np.arange(0.0, np.pi, np.pi / 100),        # orientations in [0, pi)
)
# A Distribution is one of three forms accepted by sample_distribution:
#   1. a number  - a constant value, repeated for every sample;
#   2. a dict    - {value: probability}, values drawn at random with those weights;
#   3. a tuple   - (density, step_and_start): density maps a point to its relative
#                  density, and step_and_start maps the sample count to
#                  (base step, starting point); consecutive points are spaced
#                  by base_step / density(previous point).
Distribution = typing.Union[
    float,
    typing.Dict[float, float],
    typing.Tuple[
        typing.Callable[[float], float],
        typing.Callable[[int], typing.Tuple[float, float]],
    ],
]
def sample_distribution(distribution: Distribution, size: int = 1) -> np.ndarray:
    """
    Produce ``size`` values from a Distribution description.

    :param distribution: one of
        * a number (int / float, including numpy scalars): returned ``size`` times as floats;
        * a dict ``{value: probability}``: ``size`` random draws via ``np.random.choice``;
        * a tuple ``(density, step_and_start)``: a deterministic sequence. With
          ``(step, start) = step_and_start(size)``, it starts at ``start`` and each
          next point is ``previous + step / density(previous)``.
    :param size: number of values to produce.
    :return: a 1-D array with exactly ``size`` entries (empty when ``size <= 0``).
    :raises ValueError: if ``distribution`` is none of the forms above.
    """
    # numpy scalars such as np.int64 / np.float32 are not int/float subclasses,
    # so they are accepted explicitly.
    if isinstance(distribution, (int, float, np.integer, np.floating)):
        return np.array([float(distribution)] * size)
    if isinstance(distribution, dict):
        return np.random.choice(list(distribution.keys()), size, p=list(distribution.values()))
    if isinstance(distribution, tuple):
        density, step_and_start = distribution[0], distribution[1]
        if size <= 0:
            # Match the other branches, which return no samples for size <= 0.
            return np.array([], dtype=float)
        step, start = step_and_start(size)
        res = [start]
        # The start point plus size - 1 further points gives exactly `size` values.
        # The previous loop ran `size` times and returned size + 1 values, which
        # misaligned callers that pair samples by position.
        for _ in range(size - 1):
            res.append(res[-1] + step / density(res[-1]))
        return np.array(res)
    raise ValueError(f'Unknown distribution type: {type(distribution)}')
def get_uniform_dist(start: float, stop: float) -> Distribution:
    """
    Build a tuple-form Distribution with constant density between ``start`` and ``stop``.

    The density callable always returns 1. The step callable, given ``size``,
    returns ``((stop - start) / (size - 1), start)``, so that
    ``start + (size - 1) * step == stop``. When ``size - 1`` is 0 the divisor
    falls back to 1.
    """
    def constant_density(_point):
        return 1

    def step_and_start(size):
        intervals = size - 1 or 1
        return (stop - start) / intervals, start

    return constant_density, step_and_start


# Default distributions: phases spread over [0, 2*pi], orientations over [0, pi].
phi_dist_uni = get_uniform_dist(0, 2 * np.pi)
theta_dist_uni = get_uniform_dist(0, np.pi)
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, applied elementwise by numpy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
@dataclass
class Cell:
    """
    A model neuron: the module's symbolic expressions with this cell's values
    substituted.

    Field -> symbol mapping:
      phi_val -> phi_rf, theta_val -> theta_rf, sigma_val -> sigma,
      x0_val -> x0, y0_val -> y0, k_val -> k.
    Defaults for the optional fields come from ``defaults``.
    """
    phi_val: float
    theta_val: float
    sigma_val: float = defaults[sigma]
    x0_val: float = defaults[x0]
    y0_val: float = defaults[y0]
    k_val: float = defaults[k]

    def _substitute(self, expr: sp.Expr) -> sp.Expr:
        """Substitute this cell's values into ``expr``, one symbol at a time, in a fixed order."""
        bindings = (
            (sigma, self.sigma_val),
            (x0, self.x0_val),
            (y0, self.y0_val),
            (k, self.k_val),
            (phi_rf, self.phi_val),
            (theta_rf, self.theta_val),
        )
        for symbol, value in bindings:
            expr = expr.subs(symbol, value)
        return expr

    @property
    def sympy_func(self) -> sp.Expr:
        """``receptive_field`` with this cell's parameters substituted (still a function of x and y)."""
        return self._substitute(receptive_field)

    def get_tuning_function(self) -> typing.Callable[[np.ndarray, np.ndarray], np.ndarray]:
        """
        Compile this cell's tuning expression into a numpy callable.

        :return: a function (theta_grating, phi_grating) -> value, with angles in
            radians (get_value converts degrees before calling it).
        """
        # NOTE: sp.lambdify runs on every call; callers evaluating many points
        # should keep the returned function rather than call this repeatedly.
        tuning_expr = self._substitute(p)
        return sp.lambdify((theta_grating, phi_grating), tuning_expr, 'numpy')

    def get_value(self, theta_deg: float, phi_deg: float) -> float:
        """Evaluate the tuning function at a grating orientation and phase given in degrees."""
        tuning = self.get_tuning_function()
        theta_rad = theta_deg * np.pi / 180
        phi_rad = phi_deg * np.pi / 180
        return float(tuning(theta_rad, phi_rad))

    def get_tuning_plot(self, theta_step_deg: float, phi_step_deg: float) -> np.ndarray:
        """
        Evaluate the tuning function over the grid from get_orientation_phase_grid.

        NOTE(review): assumes grid[:, :, 0] holds orientations and grid[:, :, 1]
        phases, already in radians (the tuning function takes radians) - confirm
        against utils.get_orientation_phase_grid.
        """
        grid = get_orientation_phase_grid(theta_step_deg, phi_step_deg)
        tuning = self.get_tuning_function()
        return tuning(grid[:, :, 0], grid[:, :, 1])
@dataclass
class Grating:
    """A stimulus grating: ``grating_f`` with this grating's wave number, phase and orientation angle."""
    phi_val: float = defaults[phi_grating]
    theta_val: float = defaults[theta_grating]
    k_val: float = defaults[k]

    @property
    def sympy_func(self) -> sp.Expr:
        """``grating_f`` with k, phi_grating and theta_grating substituted (still a function of x, y, x0, y0)."""
        expr = grating_f.subs(k, self.k_val)
        expr = expr.subs(phi_grating, self.phi_val)
        return expr.subs(theta_grating, self.theta_val)
@dataclass
class Population:
    """A collection of Cell objects that are evaluated together."""
    cells: typing.List[Cell] = field(default_factory=list)

    @property
    def response_func(self) -> typing.Callable[[float, float], np.ndarray]:
        """
        Build a vectorised response function for all cells from ``p`` via sp.lambdify.

        The cells' k, sigma, phi and theta values are bound as column vectors of
        shape (n_cells, 1), so the result broadcasts against the grating inputs.
        x0 and y0 are bound to 0 regardless of the cells' x0_val / y0_val.

        :return: a function (phi_grating, theta_grating) -> responses, angles in radians
        """
        return partial(
            sp.lambdify(
                (x0, y0, k, sigma, phi_rf, theta_rf, phi_grating, theta_grating, ),
                p, 'numpy'),
            0, 0, np.array([cell.k_val for cell in self.cells]).reshape((-1, 1)),
            np.array([cell.sigma_val for cell in self.cells]).reshape((-1, 1)),
            np.array([cell.phi_val for cell in self.cells]).reshape((-1, 1)),
            np.array([cell.theta_val for cell in self.cells]).reshape((-1, 1)),
        )

    @classmethod
    def sample(cls, n_orient: int, n_phases: int,
               phi_dist: Distribution = phi_dist_uni,
               theta_dist: Distribution = theta_dist_uni,
               sigma_dist: Distribution = defaults[sigma],
               k_val: float = defaults[k],
               xy_dist: Distribution = get_uniform_dist(-5, 5)):
        """
        Build a population of ``n_orient * n_phases`` cells on an orientation x phase grid.

        ``n_phases`` phases are drawn from ``phi_dist`` and ``n_orient`` orientations
        from ``theta_dist``. Cells are ordered orientation-major: the first
        ``n_phases`` cells share the first orientation and take the phases in order,
        and so on. Each cell gets its own ``sigma_dist`` sample and the common
        ``k_val``. Offsets are drawn from ``xy_dist`` but multiplied by 0, so every
        cell has x0 = y0 = 0.
        """
        n_cells = n_orient * n_phases
        # Truncate every sample to the requested count. The orientation/phase
        # pairing below is positional, so a distribution returning extra points
        # (as the tuple form of sample_distribution did: size + 1) would otherwise
        # shift phases between orientations and add a spurious orientation.
        # The draw order (phi, theta, sigma, x, y) is unchanged.
        phases = list(sample_distribution(phi_dist, n_phases)[:n_phases])
        orientations = sample_distribution(theta_dist, n_orient)[:n_orient]
        sigmas = sample_distribution(sigma_dist, n_cells)[:n_cells]
        xs = sample_distribution(xy_dist, n_cells)[:n_cells]
        ys = sample_distribution(xy_dist, n_cells)[:n_cells]
        phi_vals = phases * n_orient
        theta_vals = list(chain.from_iterable([theta] * n_phases for theta in orientations))
        return cls(cells=[
            Cell(phi_val=phi_val, theta_val=theta_val, sigma_val=sigma_val,
                 x0_val=x0_val * 0, y0_val=y0_val * 0, k_val=k_val)
            for phi_val, theta_val, sigma_val, x0_val, y0_val
            in zip(phi_vals, theta_vals, sigmas, xs, ys)
        ])

    def get_response(self, phi_deg: float, theta_deg: float, coef: float = 4, use_sigmoid: bool = True) -> np.ndarray:
        """
        Response of every cell to a grating with the given phase and orientation (degrees).

        Each cell's tuning value (Cell.get_value) is multiplied by ``coef`` and, if
        ``use_sigmoid``, passed through sigmoid.

        :return: array with one entry per cell, in ``self.cells`` order.
        """
        activation = sigmoid if use_sigmoid else (lambda x: x)
        return activation(np.array([cell.get_value(theta_deg, phi_deg) for cell in self.cells]) * coef)

    def sample_responses(
            self, n: int, noise_sigma: float = 0, coef: float = 2, use_sigmoid: bool = True,
            custom_grid: typing.Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Sample population responses to ``n`` random gratings, or to ``custom_grid``.

        :param n: number of random stimuli; ignored when ``custom_grid`` is given.
        :param noise_sigma: std of Gaussian noise added to the response channel only.
        :param coef: forwarded to get_response.
        :param use_sigmoid: forwarded to get_response.
        :param custom_grid: optional sequence of ``(phi_deg, theta_deg)`` pairs.
        :return: array of shape ``(n_stimuli, n_cells, 3)``. The last axis holds
            (response, phi_deg, theta_deg). Random stimuli draw both angles
            uniformly from [0, 360). The orientation is reduced mod 180 for the
            response but stored unreduced.
        """
        grid = np.random.uniform(0, 360, (n, 2)) if custom_grid is None else custom_grid
        n_cells = len(self.cells)
        stimuli = [
            np.array([
                self.get_response(phi_deg, theta_deg % 180, coef=coef, use_sigmoid=use_sigmoid),
                np.ones(n_cells) * phi_deg,
                np.ones(n_cells) * theta_deg,
            ]).swapaxes(0, 1)
            for phi_deg, theta_deg in grid
        ]
        # With zero stimuli np.array([]) is 1-D and would not broadcast against the
        # noise. Reshaping pins the (n_stimuli, n_cells, 3) shape in every case.
        responses = np.array(stimuli).reshape(len(grid), n_cells, 3)
        noise = np.random.normal(0, [noise_sigma, 0, 0], (len(grid), n_cells, 3))
        return responses + noise