"""
File to calculate the intensity map and broadening said map. and the velocity profile.
"""
from enum import Enum
from functools import partial
from typing import Literal, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
from numba import njit
[docs]
class RotationalBroadeningError(Exception):
"""Custom exception for rotational broadening calculation errors."""
pass
ObservationType = Literal["emission", "transmission"]
[docs]
@jax.jit
def calculate_velocity_step(wavelengths: jnp.ndarray) -> float:
"""Calculate the velocity step size from wavelength array."""
delta_wavelengths = wavelengths[1:] - wavelengths[:-1]
wavelength_means = (wavelengths[1:] + wavelengths[:-1]) / 2.0
dRV = jnp.mean(2.0 * delta_wavelengths / wavelength_means) * 2.998e5
return dRV
[docs]
@jax.jit
def get_rotational_kernel(
v_sin_i: Union[float, jnp.ndarray],
wavelengths: jnp.ndarray,
is_transmission: bool = False,
) -> jnp.ndarray:
"""
Calculate rotational broadening kernel using JAX.
Args:
v_sin_i: Projected rotational velocity in km/s
wavelengths: Array of wavelength points (must be sorted)
is_transmission: If True, use transmission profile (1/pi/sqrt(1-x²)),
else emission profile (sqrt(1-x²))
Returns:
Normalized rotational broadening kernel
"""
dRV = calculate_velocity_step(wavelengths)
# Fixed kernel size for both modes
indices = jnp.arange(401) - 200
x_positions = indices * dRV / v_sin_i
# Calculate profile based on observation mode
profile = jnp.where(
is_transmission,
1.0 / (jnp.pi * jnp.sqrt(1 - x_positions**2)), # transmission profile
jnp.sqrt(1 - x_positions**2), # emission profile
)
# Apply mask for valid regions (|x| < 1)
kernel = jnp.where(jnp.abs(x_positions) < 1.0, profile, 0.0)
# Normalize the kernel
return kernel / jnp.sum(kernel)
[docs]
def get_rot_ker(
v_sin_i: Union[float, np.ndarray, jnp.ndarray],
wavelengths: Union[np.ndarray, jnp.ndarray, list],
observation: str = "emission",
) -> jnp.ndarray:
"""
Safe wrapper for get_rotational_kernel with automatic array conversion.
Args:
v_sin_i: Projected rotational velocity in km/s
wavelengths: Array of wavelength points
observation: Type of observation, either "emission" or "transmission"
Returns:
Normalized kernel array
Raises:
RotationalBroadeningError: If calculation fails or if invalid observation type
"""
# Validate observation type
if observation not in ["emission", "transmission"]:
raise RotationalBroadeningError(
f"observation must be 'emission' or 'transmission', got '{observation}'"
)
try:
# Convert inputs to JAX arrays if needed
if not isinstance(wavelengths, jnp.ndarray):
wavelengths = jnp.array(wavelengths)
if isinstance(v_sin_i, (list, np.ndarray)):
v_sin_i = jnp.array(v_sin_i)
# Basic input validation
if wavelengths.ndim != 1:
raise RotationalBroadeningError(
f"wavelengths must be 1D array, got shape {wavelengths.shape}"
)
if wavelengths.size < 2:
raise RotationalBroadeningError(
"wavelength array must have at least 2 points"
)
if isinstance(v_sin_i, (float, int)) and v_sin_i <= 0:
raise RotationalBroadeningError(f"v_sin_i must be positive, got {v_sin_i}")
# Calculate kernel
kernel = get_rotational_kernel(
v_sin_i, wavelengths, is_transmission=(observation == "transmission")
)
if jnp.any(jnp.isnan(kernel)):
raise RotationalBroadeningError("Kernel calculation failed - check inputs")
return kernel
except Exception as e:
if not isinstance(e, RotationalBroadeningError):
raise RotationalBroadeningError(f"Calculation failed: {str(e)}")
raise
# @njit
[docs]
def get_theta(lat, lon):
"""
Converts latitude and longitude to Theta, the angular distance across the disk.
Inputs
------
:lat: latitude in radians
:lon: longitude in radians
Outputs
-------
:res1: the angular distance across the disk.
"""
if np.any(np.abs(lat) > np.pi) or np.any(np.abs(lon) > np.pi):
raise ValueError("lat or lon is too large. are you inputting in radians?")
else:
return np.arccos(np.sqrt(1 - np.sin(lon) ** 2 - np.sin(lat) ** 2))
[docs]
def fix_nan_result(res):
"""
Fixes the result of a calculation that has NaNs in it.
Parameters
----------
res : float
The result of a calculation.
Returns
-------
float
The result of the calculation with NaNs fixed.
"""
if isinstance(res, np.ndarray) or isinstance(res, list):
res[np.isnan(res)] = 0.0
else:
if np.isnan(res):
res = 0.0
return res
# @njit
[docs]
def I_darken(lon, lat, epsilon):
"""
takes lon and lat in. spits out the value on the limb-darkened disk.
Inputs
------
:lon: longitude in radians
:lat: latitude in radians
:epsilon: limb-darkening coefficient. should be between 0 and 1.
Outputs
-------
:res1: the intensity at that point on the disk.
"""
if 0.0 <= epsilon <= 1.0:
theta = get_theta(lat, lon)
res1 = (1 - epsilon) + (epsilon * np.cos(theta))
res1 = fix_nan_result(res1)
return res1
else:
raise ValueError("epsilon should only be from 0 to 1")
# @njit
[docs]
def I_darken_disk(x, y, epsilon):
"""
takes x and y in. spits out the value on the limb-darkened disk.
Inputs
------
:x: x position on the disk
:y: y position on the disk
:epsilon: limb-darkening coefficient. should be between 0 and 1.
Outputs
-------
:res1: the intensity at that point on the disk.
"""
if x**2 + y**2 <= 1:
lon = np.arcsin(x)
lat = np.arcsin(y)
return I_darken(lon, lat, epsilon)
else:
return 0.0
[docs]
def I_darken_disk_integrated(x, y, epsilon):
"""
takes x and y in. spits out the value on the limb-darkened disk.
Inputs
------
:x: x position on the disk
:y: y position on the disk
:epsilon: limb-darkening coefficient. should be between 0 and 1.
Outputs
-------
:res1: the intensity at that point on the disk.
"""
lon = np.arcsin(x)
lat = np.arcsin(y)
dTheta = dx / np.cos(lon)
dPhi = dy / np.cos(lat)
return np.sum(
dTheta * dPhi * np.cos(lon) * np.cos(lat) * I_darken(lon, lat, epsilon)
)
[docs]
@njit
def gaussian_term(lon, lat, offset, sigma, amp):
"""
the gaussian term.
Inputs
------
:lat: latitude in radians
:lon: longitude in radians
:offset: the offset of the gaussian
:sigma: the sigma of the gaussian
:amp: the amplitude of the gaussian
Outputs
-------
:res1: the gaussian term.
"""
return gaussian_term_1d(lat, 0.0, sigma, amp) * np.exp(
-((lon - offset) ** 2) / (2 * sigma**2)
)
[docs]
@njit
def gaussian_term_1d(lat, offset, sigma, amp):
"""
the gaussian term. this time there's no longitude dependence!
Inputs
------
:lat: latitude in radians
:lon: longitude in radians
:offset: the offset of the gaussian
:sigma: the sigma of the gaussian
:amp: the amplitude of the gaussian
Outputs
-------
:res1: the gaussian term.
"""
return (
amp
* (1 / (sigma**2 * 2 * np.pi))
* np.exp(-((lat - offset) ** 2) / (2 * sigma**2))
)
#### below: the actual integrations
vl = 1
[docs]
@njit
def numerator_integral_no_sigma(vz, epsilon, n_samples=100):
"""
calculates the numerator of the line profile. this is the version without the exponent.
Parameters
----------
vz : float
The velocity of the star.
epsilon : float
The limb-darkening coefficient.
n_samples : int
The number of samples to use in the integration.
Returns
-------
total_res : float
The numerator of the line profile.
"""
# set bounds
lower_bound = 0
upper_bound = np.arcsin(np.sqrt(1 - np.square(vz / vl)))
# need to fix these bounds
lat_arr = np.linspace(lower_bound, upper_bound, n_samples)
d_lat = np.diff(lat_arr)[0]
total_res = 0.0
for lat in lat_arr:
res = (
(
(1 - epsilon)
+ epsilon * np.sqrt(1 - np.square(vz / vl) - np.square(np.sin(lat)))
)
* np.cos(lat)
* d_lat
)
if not np.isnan(res):
total_res += res
return total_res
[docs]
@njit
def numerator_no_sigma(vz, epsilon, n_samples=100):
"""
calculates the numerator of the line profile. this is the version without the exponent.
Parameters
----------
vz : float
The velocity of the star.
epsilon : float
The limb-darkening coefficient.
n_samples : int
The number of samples to use in the integration.
Returns
-------
res1 : float
The numerator of the line profile.
"""
res1 = 2
return res1 * numerator_integral_no_sigma(vz, epsilon, n_samples=n_samples)
# @njit
[docs]
def line_profile_no_exponent(epsilon, dRV, n_samples=int(1e4), model="igrins", vl=5):
"""
calculates the line profile for a given epsilon and dRV. this is the version without the exponent.
Parameters
----------
epsilon : float
The limb-darkening coefficient.
dRV : float
The width of the line profile.
n_samples : int
The number of samples to use in the integration.
model : str
vl : float
The equatorial rotational velocity of the star.
Returns
-------
big_resoluts : np.ndarray
The line profile.
"""
if model == "igrins":
n_vz = 2 * vl / dRV
else:
n_vz = 500
vzs = np.linspace(
-1, 1, int(n_vz)
) # need to set to 0 if off disk. this is actually vz / vl. can also step off.
big_resoluts = np.zeros(vzs.shape)
for i, vz in enumerate(vzs):
re_num = numerator_no_sigma(vz, epsilon, n_samples=n_samples)
# vz, epsilon, mu, sigma=.7, n_samples=100
# todo: check sensitivity to lat / lon gridding
re_denom = I_darken_disk_integrated(X, Y, epsilon)
# pdb.set_trace()
big_resoluts[i] = re_num / re_denom
# normalize
big_resoluts /= np.sum(big_resoluts)
return big_resoluts
[docs]
def convert_vz_to_lon(vz, R, sigma_jet, amp_jet, mu_jet):
"""
Converts vz to a longitude.
Parameters
----------
vz : float
The velocity along the line of sight.
R : float
The radius of the star.
sigma_jet : float
The width of the jet.
amp_jet : float
The amplitude of the jet.
mu_jet : float
The center of the jet.
"""
gauss_term = gaussian_term_1d(vz, sigma_jet, amp_jet, mu_jet)
return np.arcsin(vz / (R * gauss_term))
[docs]
def convert_lon_to_vz(lon, R, sigma_jet, amp_jet, mu_jet):
"""
Converts longitude to a vz
Parameters
----------
lon : float
The longitude.
R : float
The radius of the planet.
sigma_jet : float
The width of the jet.
amp_jet : float
The amplitude of the jet.
mu_jet : float
The center of the jet.
"""
vz = (
R * np.sin(lon) * gaussian_term_1d(lon, sigma_jet, amp_jet, mu_jet)
) # todo: check another factor of R?
return vz
# @njit
[docs]
def broaden_spectrum(
wav, spectrum_flux, epsilon, n_samples=int(1e4), vl=5, model="igrins"
):
"""
Broadens a spectrum by a line profile.
Parameters
----------
wl_model : array-like
The wavelength model.
spectrum_flux : array-like
The flux of the spectrum.
epsilon : float
The spot contrast.
n_samples : int
The number of samples to use in the integration.
vl : float
The rotational velocity of the object.
Returns
-------
array-like
The broadened spectrum.
"""
# pdb.set_trace()
dRV = (
np.mean(
np.array(
[np.mean(np.diff(wav) / wav[1:]), np.mean(np.diff(wav) / wav[:-1])]
)
)
* const_c
/ 1e3
)
profile = line_profile_no_exponent(
epsilon, dRV, n_samples=n_samples, vl=vl, model=model
)
return np.convolve(spectrum_flux, profile, mode="same")
x = np.linspace(-1, 1, 80)
y = np.linspace(-1, 1, 80)
X, Y = np.meshgrid(x, y)
dx = np.diff(x)[0]
dy = np.diff(y)[0]
if __name__ == "__main__":
x = np.linspace(-1, 1, 80)
y = np.linspace(-1, 1, 80)
vl = 5
X, Y = np.meshgrid(x, y)
dx = np.diff(x)[0]
dy = np.diff(y)[0]
lon = np.linspace(-90, 90, 80)
lat = np.linspace(-90, 90, 80)
lon = np.radians(lon)
lat = np.radians(lat)
LON, LAT = np.meshgrid(lon, lat)
x = np.linspace(-1, 1, 80)
y = np.linspace(-1, 1, 80)
X, Y = np.meshgrid(x, y)
res = I_darken(LON, LAT, 0.1)
res_disk = I_darken_disk(X, Y, 0.9)