Source code for scope.ccf

"""
Calculates the cross-correlation function (and log likelihood function from the Brogi & Line 2019 mapping)
"""

from functools import wraps

import jax
import jax.numpy as jnp


[docs] def jit(f): @wraps(f) def wrapper(*args, **kwargs): return jax.jit(f(*args, **kwargs)) return wrapper
# @jit
[docs] def calc_ccf(model_flux, data_arr_slice, n_pixel): """ Calculates the CCF between a model and a data slice. Inputs ------ :model_flux: (jnp.ndarray) The model flux, normalized and such. :data_arr_slice: (jnp.ndarray) The data slice, normalized and such. :n_pixels: (int) The number of pixels in the data slice. Outputs ------- :logl: (float) The log-likelihood of the data given the model. :CCF: (float) The CCF between the model and the data. """ model_vector = jnp.subtract( model_flux, jnp.vstack(jnp.mean(model_flux, axis=1)) ) # normalized and such variance_model = jnp.var(model_vector, axis=1) variance_data = jnp.var(data_arr_slice, axis=1) cross_variance = (model_vector * data_arr_slice).sum(axis=1) / n_pixel # now need to sum ccf = (cross_variance / jnp.sqrt(variance_data * variance_model)).sum() logl = ( -0.5 * n_pixel * jnp.log(variance_data + variance_model - 2.0 * cross_variance) ).sum() return logl, ccf
calc_ccf_map = jax.vmap(calc_ccf, in_axes=(0, None, None), out_axes=0)