import re
import numpy as np
import astropy.units as u
from astropy.wcs import WCSSUB_SPECTRAL
from specutils import Spectrum1D, SpectralRegion
from specutils.utils import QuantityModel
from traitlets import Any, Bool, List, Unicode, observe
from glue.core.data import Data
from glue.core.subset import Subset, RangeSubsetState, OrState, AndState
from glue.core.link_helpers import LinkSame
from glue.core.message import SubsetDeleteMessage, SubsetUpdateMessage
from jdaviz.core.events import AddDataMessage, RemoveDataMessage, SnackbarMessage
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import PluginTemplateMixin, SpectralSubsetSelectMixin
from jdaviz.core.custom_traitlets import IntHandleEmpty
from jdaviz.configs.default.plugins.model_fitting.fitting_backend import fit_model_to_spectrum
from jdaviz.configs.default.plugins.model_fitting.initializers import (MODELS,
initialize,
get_model_parameters)
__all__ = ['ModelFitting']
class _EmptyParam:
def __init__(self, value, unit=None):
self.value = value
self.unit = unit
self.quantity = u.Quantity(self.value,
self.unit if self.unit is not None else u.dimensionless_unscaled)
[docs]@tray_registry('g-model-fitting', label="Model Fitting")
class ModelFitting(PluginTemplateMixin, SpectralSubsetSelectMixin):
dialog = Bool(False).tag(sync=True)
template_file = __file__, "model_fitting.vue"
dc_items = List([]).tag(sync=True)
form_valid_data_selection = Bool(False).tag(sync=True)
form_valid_model_component = Bool(False).tag(sync=True)
selected_data = Unicode("").tag(sync=True)
spectral_min = Any().tag(sync=True)
spectral_max = Any().tag(sync=True)
spectral_unit = Unicode().tag(sync=True)
model_label = Unicode().tag(sync=True)
cube_fit = Bool(False).tag(sync=True)
temp_name = Unicode().tag(sync=True)
temp_model = Unicode().tag(sync=True)
model_equation = Unicode().tag(sync=True)
eq_error = Bool(False).tag(sync=True)
component_models = List([]).tag(sync=True)
display_order = Bool(False).tag(sync=True)
poly_order = IntHandleEmpty(0).tag(sync=True)
# add/replace results for "fit"
add_replace_results = Bool(True).tag(sync=True)
# selected_viewer for "apply to cube"
# NOTE: this is currently cubeviz-specific so will need to be updated
# to be config-specific if using within other viewer configurations.
viewer_to_id = {'Left': 'cubeviz-0', 'Center': 'cubeviz-1', 'Right': 'cubeviz-2'}
viewers = List(['None', 'Left', 'Center', 'Right']).tag(sync=True)
selected_viewer = Unicode('None').tag(sync=True)
available_models = List(list(MODELS.keys())).tag(sync=True)
def __init__(self, *args, **kwargs):
self._spectrum1d = None
super().__init__(*args, **kwargs)
self._units = {}
self.n_models = 0
self._fitted_model = None
self._fitted_spectrum = None
self.component_models = []
self._initialized_models = {}
self._display_order = False
self.model_label = "Model"
self._window = None
self._original_mask = None
if self.app.state.settings.get("configuration") == "cubeviz":
self.cube_fit = True
self.hub.subscribe(self, AddDataMessage,
handler=self._on_viewer_data_changed)
self.hub.subscribe(self, RemoveDataMessage,
handler=self._on_viewer_data_changed)
self.hub.subscribe(self, SubsetUpdateMessage,
handler=self._on_viewer_data_changed)
self.hub.subscribe(self, SubsetDeleteMessage,
handler=self._on_viewer_data_changed)
def _on_viewer_data_changed(self, msg=None):
"""
Callback method for when data is added or removed from a viewer, or
when a subset is created, deleted, or updated. This method receives
a glue message containing viewer information in the case of the former
set of events, and updates the available data list displayed to the
user.
Notes
-----
We do not attempt to parse any data at this point, at it can cause
visible lag in the application.
Parameters
----------
msg : `glue.core.Message`
The glue message passed to this callback method.
"""
self._viewer_id = self.app._viewer_item_by_reference(
'spectrum-viewer').get('id')
viewer = self.app.get_viewer('spectrum-viewer')
self.dc_items = [layer_state.layer.label
for layer_state in viewer.state.layers
if ((not isinstance(layer_state.layer, Subset)
or not isinstance(layer_state.layer.subset_state,
(RangeSubsetState, OrState, AndState)))
and layer_state.layer.label not in self.app.fitted_models.keys())]
def _param_units(self, param, model_type=None):
"""Helper function to handle units that depend on x and y"""
y_params = ["amplitude", "amplitude_L", "intercept", "scale"]
if param == "slope":
return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"]))
elif model_type == 'Polynomial1D':
# param names are all named cN, where N is the order
order = int(float(param[1:]))
return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"])**order)
elif param == "temperature":
return str(u.K)
elif param == "scale" and model_type == "BlackBody":
return str("")
return self._units["y"] if param in y_params else self._units["x"]
def _update_parameters_from_fit(self):
"""Insert the results of the model fit into the component_models"""
for m in self.component_models:
name = m["id"]
if hasattr(self._fitted_model, "submodel_names"):
if name in self._fitted_model.submodel_names:
m_fit = self._fitted_model[name]
else:
continue
elif self._fitted_model.name == name:
m_fit = self._fitted_model
else:
# then the component was not in the fitted model
continue
temp_params = []
for i in range(0, len(m_fit.parameters)):
temp_param = [x for x in m["parameters"] if x["name"] ==
m_fit.param_names[i]]
temp_param[0]["value"] = m_fit.parameters[i]
temp_params += temp_param
m["parameters"] = temp_params
# Trick traitlets into updating the displayed values
component_models = self.component_models
self.component_models = []
self.component_models = component_models
def _update_parameters_from_QM(self):
"""
Parse out result parameters from a QuantityModel, which isn't
subscriptable with model name
"""
if hasattr(self._fitted_model, "submodel_names"):
submodel_names = self._fitted_model.submodel_names
submodels = True
else:
submodel_names = [self._fitted_model.name]
submodels = False
fit_params = self._fitted_model.parameters
param_names = self._fitted_model.param_names
for i in range(len(submodel_names)):
name = submodel_names[i]
m = [x for x in self.component_models if x["id"] == name][0]
temp_params = []
if submodels:
idxs = [j for j in range(len(param_names)) if
int(param_names[j][-1]) == i]
else:
idxs = [j for j in range(len(param_names))]
# This is complicated by needing to handle parameter names that
# have underscores in them, since QuantityModel adds an underscore
# and integer to indicate to which model a parameter belongs
for idx in idxs:
if submodels:
temp_param = [x for x in m["parameters"] if x["name"] ==
"_".join(param_names[idx].split("_")[0:-1])]
else:
temp_param = [x for x in m["parameters"] if x["name"] ==
param_names[idx]]
temp_param[0]["value"] = fit_params[idx]
temp_params += temp_param
m["parameters"] = temp_params
# Trick traitlets into updating the displayed values
component_models = self.component_models
self.component_models = []
self.component_models = component_models
def _update_initialized_parameters(self):
# If the user changes a parameter value, we need to change it in the
# initialized model
for m in self.component_models:
name = m["id"]
for param in m["parameters"]:
quant_param = u.Quantity(param["value"], param["unit"])
setattr(self._initialized_models[name], param["name"],
quant_param)
def _warn_if_no_equation(self):
if self.model_equation == "" or self.model_equation is None:
example = "+".join([m["id"] for m in self.component_models])
snackbar_message = SnackbarMessage(
f"Error: a model equation must be defined, e.g. {example}",
color='error',
sender=self)
self.hub.broadcast(snackbar_message)
return True
else:
return False
@observe("selected_data")
def _selected_data_changed(self, event):
"""
Callback method for when the user has selected data from the drop down
in the front-end. It is here that we actually parse and create a new
data object from the selected data. From this data object, unit
information is scraped, and the selected spectrum is stored for later
use in fitting.
Parameters
----------
event : str
IPyWidget callback event object. In this case, represents the data
label of the data collection object selected by the user.
"""
selected_spec = self.app.get_data_from_viewer("spectrum-viewer",
data_label=event['new'])
# Replace NaNs from collapsed Spectrum1D in Cubeviz
# (won't affect calculations because these locations are masked)
selected_spec.flux[np.isnan(selected_spec.flux)] = 0.0
# Save original mask so we can reset after applying a subset mask
self._original_mask = selected_spec.mask
if self._units == {}:
self._units["x"] = str(
selected_spec.spectral_axis.unit)
self._units["y"] = str(
selected_spec.flux.unit)
self._spectrum1d = selected_spec
# Also set the spectral min and max to default to the full range
# This is no longer needed for 1D but is preserved for now pending
# fixes to Cubeviz for multi-subregion subsets
self._window = None
self.spectral_min = selected_spec.spectral_axis[0].value
self.spectral_max = selected_spec.spectral_axis[-1].value
self.spectral_unit = str(selected_spec.spectral_axis.unit)
@observe("spectral_subset_selected")
def _on_spectral_subset_selected(self, event):
# If "Entire Spectrum" selected, reset based on bounds of selected data
if self._spectrum1d is None:
# TODO: this should be removed as soon as the data dropdown component is
# created and defaults at init
return
if self.spectral_subset_selected == "Entire Spectrum":
self._window = None
self.spectral_min = self._spectrum1d.spectral_axis[0].value
self.spectral_max = self._spectrum1d.spectral_axis[-1].value
else:
spec_sub = self.spectral_subset.selected_obj
unit = u.Unit(self.spectral_unit)
if hasattr(spec_sub, "center"):
spreg = SpectralRegion.from_center(spec_sub.center.x * unit,
spec_sub.width * unit)
self._window = (spreg.lower, spreg.upper)
self.spectral_min = spreg.lower.value
self.spectral_max = spreg.upper.value
[docs] def vue_model_selected(self, event):
# Add the model selected to the list of models
self.temp_model = event
if event == "Polynomial1D":
self.display_order = True
else:
self.display_order = False
def _reinitialize_with_fixed(self):
"""
Reinitialize all component models with current values and the
specified parameters fixed (can't easily update fixed dictionary in
an existing model)
"""
temp_models = []
for m in self.component_models:
fixed = {}
# Set the initial values as quantities to make sure model units
# are set correctly.
initial_values = {p["name"]: u.Quantity(p["value"], p["unit"]) for p in m["parameters"]}
for p in m["parameters"]:
fixed[p["name"]] = p["fixed"]
# Have to initialize with fixed dictionary
temp_model = MODELS[m["model_type"]](name=m["id"], fixed=fixed,
**initial_values, **m.get("model_kwargs", {}))
temp_models.append(temp_model)
return temp_models
[docs] def vue_add_model(self, event):
"""Add the selected model and input string ID to the list of models"""
# validate provided label (only allow "word characters"). These should already be
# stripped by JS in the UI element, but we'll confirm here (especially if this is ever
# extended to have better API-support)
if re.search(r'\W+', self.temp_name):
raise ValueError(f"invalid model component ID {self.temp_name}")
if self.temp_name in [cm['id'] for cm in self.component_models]:
raise ValueError(f"model component ID {self.temp_name} already in use")
new_model = {"id": self.temp_name, "model_type": self.temp_model,
"parameters": [], "model_kwargs": {}}
model_cls = MODELS[self.temp_model]
if self.temp_model == "Polynomial1D":
# self.poly_order is the value in the widget for creating
# the new model component. We need to store that with the
# model itself as the value could change for another component.
new_model["model_kwargs"] = {"degree": self.poly_order}
elif self.temp_model == "BlackBody":
new_model["model_kwargs"] = {"output_units": self._units["y"],
"bounds": {"scale": (0.0, None)}}
initial_values = {}
for param_name in get_model_parameters(model_cls, new_model["model_kwargs"]):
# access the default value from the model class itself
default_param = getattr(model_cls, param_name, _EmptyParam(0))
default_units = self._param_units(param_name,
model_type=new_model["model_type"])
if default_param.unit is None:
# then the model parameter accepts unitless, but we want
# to pass with appropriate default units
initial_val = u.Quantity(default_param.value, default_units)
else:
# then the model parameter has default units. We want to pass
# with jdaviz default units (based on x/y units) but need to
# convert the default parameter unit to these units
initial_val = default_param.quantity.to(default_units)
initial_values[param_name] = initial_val
initialized_model = initialize(
MODELS[self.temp_model](name=self.temp_name,
**initial_values,
**new_model.get("model_kwargs", {})),
self._spectrum1d.spectral_axis,
self._spectrum1d.flux)
# need to loop over parameters again as the initializer may have overridden
# the original default value
for param_name in get_model_parameters(model_cls, new_model["model_kwargs"]):
param_quant = getattr(initialized_model, param_name)
new_model["parameters"].append({"name": param_name,
"value": param_quant.value,
"unit": str(param_quant.unit),
"fixed": False})
self._initialized_models[self.temp_name] = initialized_model
new_model["Initialized"] = True
self.component_models = self.component_models + [new_model]
[docs] def vue_remove_model(self, event):
self.component_models = [x for x in self.component_models
if x["id"] != event]
del(self._initialized_models[event])
[docs] def vue_equation_changed(self, event):
# Length is a dummy check to test the infrastructure
if len(self.model_equation) > 20:
self.eq_error = True
[docs] def vue_model_fitting(self, *args, **kwargs):
"""
Run fitting on the initialized models, fixing any parameters marked
as such by the user, then update the displayed parameters with fit
values
"""
if self._warn_if_no_equation():
return
models_to_fit = self._reinitialize_with_fixed()
# Apply mask from selected subset
if self.spectral_subset_selected != "Entire Spectrum":
subset_mask = self.app.get_data_from_viewer("spectrum-viewer",
data_label = self.spectral_subset_selected).mask # noqa
if self._spectrum1d.mask is None:
self._spectrum1d.mask = subset_mask
else:
self._spectrum1d.mask += subset_mask
try:
fitted_model, fitted_spectrum = fit_model_to_spectrum(
self._spectrum1d,
models_to_fit,
self.model_equation,
run_fitter=True)
except AttributeError:
msg = SnackbarMessage("Unable to fit: model equation may be invalid",
color="error", sender=self)
self.hub.broadcast(msg)
return
self._fitted_model = fitted_model
self._fitted_spectrum = fitted_spectrum
self.app.fitted_models[self.model_label] = fitted_model
self.vue_register_spectrum({"spectrum": fitted_spectrum})
# Update component model parameters with fitted values
if type(self._fitted_model) == QuantityModel:
self._update_parameters_from_QM()
else:
self._update_parameters_from_fit()
# Also update the _initialized_models so we can use these values
# as the starting point for cube fitting
self._update_initialized_parameters()
# Reset the data mask in case we use a different subset next time
self._spectrum1d.mask = self._original_mask
[docs] def vue_fit_model_to_cube(self, *args, **kwargs):
if self._warn_if_no_equation():
return
if self.selected_data in self.app.data_collection.labels:
data = self.app.data_collection[self.selected_data]
else: # User selected some subset from spectrum viewer, just use original cube
data = self.app.data_collection[0]
# First, ensure that the selected data is cube-like. It is possible
# that the user has selected a pre-existing 1d data object.
if data.ndim != 3:
snackbar_message = SnackbarMessage(
f"Selected data {self.selected_data} is not cube-like",
color='error',
sender=self)
self.hub.broadcast(snackbar_message)
return
# Get the primary data component
attribute = data.main_components[0]
component = data.get_component(attribute)
temp_values = data.get_data(attribute)
# Transpose the axis order
values = np.moveaxis(temp_values, 0, -1) * u.Unit(component.units)
# We manually create a Spectrum1D object from the flux information
# in the cube we select
wcs = data.coords.sub([WCSSUB_SPECTRAL])
spec = Spectrum1D(flux=values, wcs=wcs)
# TODO: in vuetify >2.3, timeout should be set to -1 to keep open
# indefinitely
snackbar_message = SnackbarMessage(
"Fitting model to cube...",
loading=True, timeout=0, sender=self)
self.hub.broadcast(snackbar_message)
# Retrieve copy of the models with proper "fixed" dictionaries
models_to_fit = self._reinitialize_with_fixed()
try:
fitted_model, fitted_spectrum = fit_model_to_spectrum(
spec,
models_to_fit,
self.model_equation,
run_fitter=True,
window=self._window)
except ValueError:
snackbar_message = SnackbarMessage(
"Cube fitting failed",
color='error', loading=False, sender=self)
self.hub.broadcast(snackbar_message)
raise
# Save fitted 3D model in a way that the cubeviz
# helper can access it.
for m in fitted_model:
temp_label = "{} ({}, {})".format(self.model_label, m["x"], m["y"])
self.app.fitted_models[temp_label] = m["model"]
# Transpose the axis order back
values = np.moveaxis(fitted_spectrum.flux.value, -1, 0)
count = max(map(lambda s: int(next(iter(re.findall(r"\d$", s)), 0)),
self.data_collection.labels)) + 1
label = f"{self.model_label} [Cube] {count}"
# Create new glue data object
output_cube = Data(label=label,
coords=data.coords)
output_cube['flux'] = values
output_cube.get_component('flux').units = \
fitted_spectrum.flux.unit.to_string()
# Add to data collection
self.app.add_data(output_cube, label)
if self.selected_viewer != 'None':
# replace the contents in the selected viewer with the results from this plugin
self.app.add_data_to_viewer(self.viewer_to_id.get(self.selected_viewer),
label, clear_other_data=True)
snackbar_message = SnackbarMessage(
"Finished cube fitting",
color='success', loading=False, sender=self)
self.hub.broadcast(snackbar_message)
[docs] def vue_register_spectrum(self, event):
"""
Add a spectrum to the data collection based on the currently displayed
parameters (these could be user input or fit values).
"""
if self._warn_if_no_equation():
return
# Make sure the initialized models are updated with any user-specified
# parameters
self._update_initialized_parameters()
# Need to run the model fitter with run_fitter=False to get spectrum
if "spectrum" in event:
spectrum = event["spectrum"]
else:
model, spectrum = fit_model_to_spectrum(self._spectrum1d,
self._initialized_models.values(),
self.model_equation,
window=self._window)
self.n_models += 1
label = self.model_label
if label in self.data_collection:
self.app.remove_data_from_viewer('spectrum-viewer', label)
# Remove the actual Glue data object from the data_collection
self.data_collection.remove(self.data_collection[label])
self.app.add_data(spectrum, label)
# Link the result spectrum to the reference data of the spectrum viewer
ref_data = self.app.get_viewer('spectrum-viewer').state.reference_data
data_id = ref_data.world_component_ids[0]
model_id = self.app.session.data_collection[label].world_component_ids[0]
self.app.session.data_collection.add_link(LinkSame(data_id, model_id))
if self.add_replace_results:
self.app.add_data_to_viewer('spectrum-viewer', label)