Source code for ndscan.experiment.entry_point

r"""
Top-level functions for launching :class:`.ExpFragment`\ s and their scans from the rest
of the ARTIQ ``HasEnvironment`` universe.

The two main entry points into the :class:`.ExpFragment` universe are

 * scans (with axes and overrides set from the dashboard UI) via
   :meth:`make_fragment_scan_exp`, and
 * manually launched fragments from vanilla ARTIQ ``EnvExperiment``\ s using
   :meth:`run_fragment_once` or :meth:`create_and_run_fragment_once`.
"""

import logging
import random
import time
from collections import OrderedDict
from collections.abc import Callable, Iterable
from contextlib import suppress
from functools import reduce
from typing import Any

from artiq.coredevice.exceptions import RTIOUnderflow
from artiq.language import (
    EnvExperiment,
    HasEnvironment,
    PYONValue,
    TerminationRequested,
    kernel,
    portable,
    rpc,
)

from ..utils import (
    PARAMS_ARG_KEY,
    SCHEMA_REVISION,
    SCHEMA_REVISION_KEY,
    NoAxesMode,
    merge_no_duplicates,
    shorten_to_unambiguous_suffixes,
    strip_suffix,
)
from .default_analysis import AnnotationContext
from .fragment import (
    ExpFragment,
    Fragment,
    RestartKernelTransitoryError,
    TransitoryError,
)
from .parameters import ParamBase, ParamStore
from .result_channels import (
    AppendingDatasetSink,
    LastValueSink,
    ResultChannel,
    ScalarDatasetSink,
)
from .scan_generator import GENERATORS, INT_GENERATORS, ScanGenerator, ScanOptions
from .scan_runner import (
    ScanAxis,
    ScanSpec,
    describe_analyses,
    describe_scan,
    filter_default_analyses,
    select_runner_class,
)
from .utils import dump_json, is_kernel, to_metadata_broadcast_type

# Names exported by ``from ... import *`` for this module.
__all__ = [
    "ArgumentInterface",
    "TopLevelRunner",
    "make_fragment_scan_exp",
    "run_fragment_once",
    "create_and_run_fragment_once",
]

# Hack: Only export FragmentScanExperiment when imported from Sphinx autodoc, so
# documentation is generated, but ARTIQ master doesn't try (and fail) to instantiate it
# when building the experiment list for modules using ndscan.experiment star imports.
# The check below treats the presence of a loaded "sphinx" module as "running under
# Sphinx".
import sys

if "sphinx" in sys.modules:
    __all__.append("FragmentScanExperiment")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ScanSpecError(Exception):
    """Signals that the scan specification supplied via :data:`PARAMS_ARG_KEY` cannot
    be applied to the fragment in question.
    """
def get_class_pretty_name(cls: type[object]) -> str:
    """Return the "pretty name" for the passed class, as used by ARTIQ to display in
    the experiment explorer.

    (ARTIQ unfortunately doesn't expose this, so reimplement the lookup here.)

    :param cls: The class to derive the display name for.
    :return: The first line of the class docstring, stripped of surrounding
        whitespace and passed through ``strip_suffix(…, ".")``; or ``cls.__name__``
        if the class has no docstring or the docstring contains only whitespace.
    """
    doc = (cls.__doc__ or "").strip()
    if not doc:
        # Covers None, "" and whitespace-only docstrings; for the latter,
        # splitlines() would yield an empty list and indexing [0] would raise.
        return cls.__name__
    return strip_suffix(doc.splitlines()[0].strip(), ".")
class FragmentScanExperiment(EnvExperiment):
    """Runs an :class:`.ExpFragment` as a (possibly trivial) scan.

    Overrides and scan axes are taken from the :data:`PARAMS_ARG_KEY` argument, and
    result channels are broadcast to datasets by way of a :class:`TopLevelRunner`.

    See :meth:`make_fragment_scan_exp` for a convenience method to create subclasses
    for a specific :class:`.ExpFragment`.
    """

    argument_ui = "ndscan"

    def build(
        self,
        fragment_init: Callable[[], ExpFragment],
        max_rtio_underflow_retries: int = 3,
        max_transitory_error_retries: int = 10,
    ):
        """Create the fragment and register the ndscan argument.

        :param fragment_init: Callable to create the top-level :class:`ExpFragment`
            instance.
        :param max_rtio_underflow_retries: Number of RTIOUnderflows to tolerate per
            scan point (by simply trying again) before giving up.
        :param max_transitory_error_retries: Number of transitory errors to tolerate
            per scan point (by simply trying again) before giving up.
        """
        self.fragment = fragment_init()
        self.max_transitory_error_retries = max_transitory_error_retries
        self.max_rtio_underflow_retries = max_rtio_underflow_retries
        self.args = ArgumentInterface(self, [self.fragment], scannable=True)

    def prepare(self):
        """Gather the parameter stores for scan axes and plain overrides, apply them
        to the fragment, and set up the :class:`TopLevelRunner`.

        :raises ValueError: If an override/axis store ended up without any parameter
            handles attached after initialising the fragment parameters.
        """
        stores_by_fqn = self.args.make_override_stores()
        scan_spec, no_axes_mode, skip_persistent = self.args.make_scan_spec()

        # Scan axes are fed to the fragment through the same mechanism as overrides.
        for axis in scan_spec.axes:
            axis_fqn = axis.param_schema["fqn"]
            stores_by_fqn.setdefault(axis_fqn, []).append(
                (axis.path, axis.param_store)
            )

        self.fragment.init_params(stores_by_fqn)

        for fqn, entries in stores_by_fqn.items():
            for path, store in entries:
                if store._handles:
                    continue
                # This isn't strictly a problem; we could continue on just fine, or
                # make this just a warning. However, there is no reason for the user
                # to specify an override spec that will be silently ignored in the
                # first place, and failing loudly prevents any unnoticed issues if
                # the experiment code is modified for an already scheduled servo/
                # already open argument UI window/…
                raise ValueError(
                    f"Override for '{fqn}' in path '{path}' did not match any "
                    "parameters (likely due to changes since made to the "
                    "experiment code; try Recompute All Arguments)."
                )

        self.tlr = TopLevelRunner(
            self,
            fragment=self.fragment,
            spec=scan_spec,
            no_axes_mode=no_axes_mode,
            max_rtio_underflow_retries=self.max_rtio_underflow_retries,
            max_transitory_error_retries=self.max_transitory_error_retries,
            skip_on_persistent_transitory_error=skip_persistent,
        )

    def run(self):
        """Create the applet for this scan and execute it; a
        ``TerminationRequested`` raised while running is swallowed.
        """
        pretty_name = get_class_pretty_name(self.fragment.__class__)
        self.tlr.create_applet(title=f"{pretty_name} ({self.fragment.fqn})")
        try:
            self.tlr.run()
        except TerminationRequested:
            pass

    def analyze(self):
        """Delegate to :meth:`TopLevelRunner.analyze`."""
        self.tlr.analyze()
def select_generator_class(
    scan_type: str, param_type: str
) -> type[ScanGenerator] | None:
    """Return the :class:`ScanGenerator` subclass implementing ``scan_type`` for a
    parameter of the given type (as per the schema), or ``None`` if no implementation
    matches.

    Integer parameters use the :data:`INT_GENERATORS` specialisations where available,
    so each integer is visited at most once and refining scans terminate.

    :param scan_type: Key identifying the kind of scan (the ``"type"`` entry of an
        axis specification).
    :param param_type: The ``"type"`` entry of the parameter schema.
    """
    if param_type == "int":
        if (cls := INT_GENERATORS.get(scan_type)) is not None:
            return cls
    return GENERATORS.get(scan_type)


class ArgumentInterface(HasEnvironment):
    """Exposes the parameters of a list of fragments as a single PYON argument
    (:data:`PARAMS_ARG_KEY`), and converts the value received for it back into
    parameter override stores and a scan specification.
    """

    def build(self, fragments: list[Fragment], scannable: bool = False) -> None:
        """Collect the parameter description and request the argument.

        :param fragments: The fragments whose parameters are to be exposed.
        :param scannable: If ``True``, a default ``"scan"`` section (no axes, one
            repeat) is included in the argument default.
        """
        self._fragments = fragments

        instances = dict[str, list[str]]()
        self._schemata = dict[str, dict]()
        self._sample_instances = dict[str, ParamBase]()
        always_shown_params = []
        for fragment in fragments:
            fragment._collect_params(instances, self._schemata, self._sample_instances)
            for handle in fragment.get_always_shown_params():
                path = handle.owner._stringize_path()
                try:
                    param = handle.owner._free_params[handle.name]
                    always_shown_params += [(param.fqn, path)]
                except KeyError:
                    # This is benign if a parameter explicitly listed in a fragment's
                    # get_always_shown_params() is subsequently overridden in a parent
                    # fragment/subscan/…
                    logger.debug(
                        "Parameter '%s' specified in get_always_shown_params()"
                        " is not a free parameter of fragment '%s'",
                        handle.name,
                        path,
                    )
        desc = {
            "instances": instances,
            "schemata": self._schemata,
            "always_shown": always_shown_params,
            "overrides": {},
        }
        if scannable:
            desc["scan"] = {
                "axes": [],
                "num_repeats": 1,
                "no_axes_mode": "single",
                "randomise_order_globally": False,
            }
        self._params = self.get_argument(PARAMS_ARG_KEY, PYONValue(default=desc))

    def make_override_stores(self) -> dict[str, list[tuple[str, ParamStore]]]:
        """Build parameter stores for the ``"overrides"`` section of the argument.

        :return: A dictionary mapping parameter FQNs to lists of
            ``(path, store)`` tuples, one per override given for that FQN.
        :raises KeyError: If an override refers to an FQN not present among the
            parameters collected in :meth:`build`.
        """
        stores = {}
        for fqn, specs in self._params.get("overrides", {}).items():
            try:
                store_type = self._sample_instances[fqn].StoreType
            except KeyError:
                # The bare lookup error carries no useful information for the user,
                # so suppress it as context of the more descriptive one.
                raise KeyError(
                    "Experiment does not have parameters matching "
                    f"override for FQN '{fqn}' "
                    + "(likely due to outdated argument editor after "
                    + "changes to experiment; try Recompute All Arguments)"
                ) from None

            stores[fqn] = [
                (
                    s["path"],
                    store_type(
                        (fqn, s["path"]), store_type.value_from_pyon(s["value"])
                    ),
                )
                for s in specs
            ]
        return stores

    def make_scan_spec(self) -> tuple[ScanSpec, NoAxesMode, bool]:
        """Build the scan specification from the ``"scan"`` section of the argument.

        :return: A tuple ``(spec, no_axes_mode, skip_on_persistent_transitory_error)``.
        :raises KeyError: If an axis refers to an FQN not present among the parameters
            collected in :meth:`build`.
        :raises ScanSpecError: If no generator is implemented for an axis type.
        """
        scan = self._params.get("scan", {})

        generators = []
        axes = []
        for axspec in scan.get("axes", []):
            fqn = axspec["fqn"]
            pathspec = axspec["path"]
            try:
                schema = self._schemata[fqn]
                store_type = self._sample_instances[fqn].StoreType
            except KeyError:
                raise KeyError(
                    "Experiment does not have parameters matching "
                    f"scan axis with FQN '{fqn}' "
                    + "(likely due to outdated argument editor after "
                    + "changes to experiment; try Recompute All Arguments)"
                ) from None

            scan_type = axspec["type"]
            generator_class = select_generator_class(scan_type, schema["type"])
            if generator_class is None:
                raise ScanSpecError(f"Axis type '{scan_type}' not implemented")
            generator = generator_class(**axspec["range"])
            generators.append(generator)

            # The store is initialised with the first point the generator yields.
            first_value = generator.points_for_level(0, random)[0]
            store = store_type((fqn, pathspec), store_type.value_from_pyon(first_value))
            axes.append(ScanAxis(schema, pathspec, store))

        options = ScanOptions(
            scan.get("num_repeats", 1),
            scan.get("num_repeats_per_point", 1),
            scan.get("randomise_order_globally", False),
        )
        spec = ScanSpec(axes, generators, options)
        no_axes_mode = NoAxesMode[scan.get("no_axes_mode", "single")]
        skip_on_persistent_transitory_error = scan.get(
            "skip_on_persistent_transitory_error", False
        )
        return spec, no_axes_mode, skip_on_persistent_transitory_error


class TopLevelRunner(HasEnvironment):
    """Drives the execution of a top-level :class:`.ExpFragment`.

    Sets up dataset sinks for the result channels, broadcasts the scan metadata, runs
    either a scan over the given axes or (if there are none) a single point/continuous
    acquisition, and executes the default analyses afterwards.
    """

    def build(
        self,
        fragment: ExpFragment,
        spec: ScanSpec,
        no_axes_mode: NoAxesMode = NoAxesMode.single,
        max_rtio_underflow_retries: int = 3,
        max_transitory_error_retries: int = 10,
        dataset_prefix: str | None = None,
        skip_on_persistent_transitory_error: bool = False,
    ):
        """Set up result channel sinks and analyses, and prepare the fragment.

        :param fragment: The fragment to run.
        :param spec: The scan to execute. If it has no axes and ``no_axes_mode`` is
            ``NoAxesMode.time_series``, a ``"timestamp"`` axis is added to it.
        :param no_axes_mode: Only consulted if ``spec`` has no axes; anything other
            than ``NoAxesMode.single`` keeps acquiring points until paused/terminated.
        :param max_rtio_underflow_retries: Number of RTIOUnderflows to tolerate per
            point before giving up.
        :param max_transitory_error_retries: Number of transitory errors to tolerate
            per point before giving up.
        :param dataset_prefix: Prefix for all datasets written. Defaults to
            ``"ndscan.rid_<rid>."``; a trailing dot is appended to a non-empty prefix
            if not already present.
        :param skip_on_persistent_transitory_error: Forwarded to the scan runner.
        """
        self.fragment = fragment
        self.spec = spec
        self.max_rtio_underflow_retries = max_rtio_underflow_retries
        self.max_transitory_error_retries = max_transitory_error_retries
        self.skip_on_persistent_transitory_error = skip_on_persistent_transitory_error

        self.setattr_device("ccb")
        self.setattr_device("core")
        self.setattr_device("scheduler")

        if dataset_prefix is None:
            dataset_prefix = f"ndscan.rid_{self.scheduler.rid}."
        elif dataset_prefix and not dataset_prefix.endswith("."):
            # Add trailing dot to dataset prefix if not given – the same bare prefix
            # mushed into all the ndscan datasets isn't what a user with an intact sense
            # of style would intend.
            dataset_prefix += "."
        self.dataset_prefix = dataset_prefix

        # FIXME: We save these as individual booleans as enums crash the ARTIQ compiler.
        self._continue_running = False
        self._is_time_series = False
        if not self.spec.axes:
            self._continue_running = no_axes_mode != NoAxesMode.single
            if no_axes_mode == NoAxesMode.time_series:
                self._is_time_series = True
                param_schema = {
                    "type": "float",
                    "fqn": "timestamp",
                    "description": "Elapsed time",
                    "default": "0.0",
                    "spec": {"min": 0.0, "unit": "s", "scale": 1.0},
                }
                self.spec.axes = [ScanAxis(param_schema, "*", None)]

        # Initialise result channels.
        chan_dict = {}
        self.fragment._collect_result_channels(chan_dict)

        chan_name_map = _shorten_result_channel_names(chan_dict.keys())

        self._scan_result_sinks = {}
        self._short_child_channel_names = {}
        for path, channel in chan_dict.items():
            if not channel.save_by_default:
                continue
            name = chan_name_map[path].replace("/", "_")
            self._short_child_channel_names[channel] = name

            if self.spec.axes:
                sink = AppendingDatasetSink(
                    self, self.dataset_prefix + "points.channel_" + name
                )
            else:
                sink = ScalarDatasetSink(self, self.dataset_prefix + "point." + name)
            channel.set_sink(sink)
            self._scan_result_sinks[channel] = sink

        # Filter analyses, set up analysis result channels, and keep track of all the
        # names in the annotation context.
        self._analyses = (
            []
            if self._is_time_series
            else filter_default_analyses(self.fragment, self.spec.axes)
        )

        self._analysis_results = reduce(
            lambda x, y: merge_no_duplicates(x, y, kind="analysis result"),
            (a.get_analysis_results() for a in self._analyses),
            {},
        )
        for name, channel in self._analysis_results.items():
            channel.set_sink(
                ScalarDatasetSink(self, self.dataset_prefix + "analysis_result." + name)
            )

        # Map (fqn, path) identities of the scanned parameters to their axis index.
        axis_indices = {
            (axis.param_schema["fqn"], axis.path): i
            for i, axis in enumerate(self.spec.axes)
        }
        self._annotation_context = AnnotationContext(
            lambda handle: axis_indices[handle._store.identity],
            lambda channel: self._short_child_channel_names[channel],
            lambda channel: True,
        )

        self._coordinate_sinks = None

        self.fragment.prepare()

    def run(self):
        """Run the (possibly trivial) scan.

        :return: A tuple ``(coordinates, values)``. Without any axes, ``coordinates``
            is ``None`` and ``values`` maps result channels to their last value;
            otherwise, ``coordinates`` is an ordered dictionary mapping
            ``(fqn, path)`` to all axis values, and ``values`` maps result channels
            to all their values.
        """
        self._broadcast_metadata()

        if not self.spec.axes and not self._is_time_series:
            self._run_continuous()
            return None, {c: s.get_last() for c, s in self._scan_result_sinks.items()}

        if self._is_time_series:
            self._timestamp_sink = AppendingDatasetSink(
                self, self.dataset_prefix + "points.axis_0"
            )
            self._coordinate_sinks = [self._timestamp_sink]
            self._time_series_start = time.monotonic()
            self._run_continuous()
        else:
            runner = select_runner_class(self.fragment)(
                self,
                max_rtio_underflow_retries=self.max_rtio_underflow_retries,
                max_transitory_error_retries=self.max_transitory_error_retries,
                skip_on_persistent_transitory_error=self.skip_on_persistent_transitory_error,
            )
            self._coordinate_sinks = [
                AppendingDatasetSink(self, self.dataset_prefix + f"points.axis_{i}")
                for i in range(len(self.spec.axes))
            ]
            runner.run(self.fragment, self.spec, self._coordinate_sinks)
            self._set_completed()

        return self._make_coordinate_dict(), self._make_value_dict()

    def _make_coordinate_dict(self):
        """Return an ordered mapping from ``(fqn, path)`` of each axis to all the
        values pushed to the matching coordinate sink.
        """
        return OrderedDict(
            ((a.param_schema["fqn"], a.path), s.get_all())
            for a, s in zip(self.spec.axes, self._coordinate_sinks)
        )

    def _make_value_dict(self):
        """Return a mapping from result channels to all values in their sinks."""
        return {c: s.get_all() for c, s in self._scan_result_sinks.items()}

    def analyze(self):
        """Execute the default analyses on the acquired data, if any.

        :return: ``None`` if there are no coordinate sinks or no analyses; otherwise a
            dictionary mapping analysis result names to their last values.
        """
        if self._coordinate_sinks is None:
            # Continuous scan or got an exception early on, so there is no data to
            # analyse – gracefully ignore this to keep FragmentScanExperiment
            # implementation simple.
            return

        if not self._analyses:
            return

        annotations = []
        coordinates = self._make_coordinate_dict()
        values = self._make_value_dict()
        for a in self._analyses:
            annotations += a.execute(coordinates, values, self._annotation_context)

        if annotations:
            # Replace existing (online-fit) annotations if any analysis produced custom
            # ones. This could be made configurable in the future.
            self.set_dataset(
                self.dataset_prefix + "annotations",
                dump_json(annotations),
                broadcast=True,
            )

        return {
            name: channel.sink.get_last()
            for name, channel in self._analysis_results.items()
        }

    def _run_continuous(self):
        """Acquire points without a scan runner until the loop reports completion,
        pausing via the scheduler in between; marks the scan as completed on exit.
        """
        self._point_phase = False

        # TODO: Unify with _FragmentRunner.
        self.num_current_transitory_errors = 0
        self.num_current_underflows = 0

        try:
            while True:
                # After every pause(), pull in dataset changes (immediately as well to
                # catch changes between the time the experiment is prepared and when it
                # is run, to keep the semantics uniform).
                self.fragment.recompute_param_defaults()
                try:
                    self.fragment.host_setup()
                    if is_kernel(self.fragment.run_once):
                        done = self._run_continuous_kernel()
                        self.core.comm.close()
                        if done:
                            break
                    else:
                        if self._continuous_loop():
                            break
                finally:
                    self.fragment.host_cleanup()
                self.scheduler.pause()
        finally:
            self._set_completed()

    @kernel
    def _run_continuous_kernel(self):
        self.core.reset()
        return self._continuous_loop()

    @portable
    def _continuous_loop(self):
        # Returns True once acquisition is finished, and False if the caller should
        # invoke the loop again (pause requested or kernel restart required).
        # TODO: Unify with _FragmentRunner.
        try:
            while not self.scheduler.check_pause():
                try:
                    self.fragment.device_setup()
                    self.fragment.run_once()
                    self._finish_continuous_point()
                    if not self._continue_running:
                        return True
                    # One point is now finished, so reset transitory error counters for
                    # the next one.
                    self.num_current_transitory_errors = 0
                    self.num_current_underflows = 0
                except RTIOUnderflow:
                    self.num_current_underflows += 1
                    if self.num_current_underflows > self.max_rtio_underflow_retries:
                        raise
                    print(
                        "Ignoring RTIOUnderflow (",
                        self.num_current_underflows,
                        "/",
                        self.max_rtio_underflow_retries,
                        ")",
                    )
                except RestartKernelTransitoryError:
                    self.num_current_transitory_errors += 1
                    if (
                        self.num_current_transitory_errors
                        > self.max_transitory_error_retries
                    ):
                        raise
                    print(
                        "Caught transitory error (",
                        self.num_current_transitory_errors,
                        "/",
                        self.max_transitory_error_retries,
                        "), restarting kernel",
                    )
                    return False
                except TransitoryError:
                    self.num_current_transitory_errors += 1
                    if (
                        self.num_current_transitory_errors
                        > self.max_transitory_error_retries
                    ):
                        raise
                    print(
                        "Caught transitory error (",
                        self.num_current_transitory_errors,
                        "/",
                        self.max_transitory_error_retries,
                        "), retrying",
                    )
            return False
        finally:
            self.fragment.device_cleanup()
        assert False, "Execution never reaches here, return is just to pacify compiler."
        return True

    @rpc(flags={"async"})
    def _finish_continuous_point(self):
        """Record completion of a point: push the elapsed time for time series, or
        toggle the broadcast ``point_phase`` dataset otherwise.
        """
        if self._is_time_series:
            self._timestamp_sink.push(time.monotonic() - self._time_series_start)
        else:
            self._point_phase = not self._point_phase
            self.set_dataset(
                self.dataset_prefix + "point_phase", self._point_phase, broadcast=True
            )

    def _set_completed(self):
        """Broadcast the ``completed`` dataset as ``True``."""
        self.set_dataset(self.dataset_prefix + "completed", True, broadcast=True)

    def _broadcast_metadata(self):
        """Broadcast schema revision, source id, completion flag and the scan/analysis
        description under the dataset prefix.
        """

        def push(name, value):
            self.set_dataset(self.dataset_prefix + name, value, broadcast=True)

        push(SCHEMA_REVISION_KEY, SCHEMA_REVISION)

        source_prefix = self.get_dataset("system_id", default="rid")
        push("source_id", f"{source_prefix}_{self.scheduler.rid}")

        push("completed", False)

        self._scan_desc = describe_scan(
            self.spec, self.fragment, self._short_child_channel_names
        )
        self._scan_desc.update(
            describe_analyses(self._analyses, self._annotation_context)
        )
        self._scan_desc["analysis_results"] = {
            name: channel.describe()
            for name, channel in self._analysis_results.items()
        }

        for name, value in self._scan_desc.items():
            # Flatten arrays/dictionaries to JSON strings for HDF5 compatibility.
            ds_value = to_metadata_broadcast_type(value)
            if ds_value is None:
                push(name, dump_json(value))
            else:
                push(name, ds_value)

    def create_applet(self, title: str, group: str = "ndscan"):
        """Issue a ``create_applet`` request via the ``ccb`` device for an
        ``ndscan.applet`` instance displaying the datasets under this prefix.

        :param title: The applet title.
        :param group: The applet group.
        """
        cmd = [
            "${python}",
            "-m ndscan.applet",
            "--server=${server}",
            "--port-notify=${port_notify}",
            "--port-control=${port_control}",
            f"--prefix={self.dataset_prefix}",
        ]
        self.ccb.issue("create_applet", title, " ".join(cmd), group=group)


def _shorten_result_channel_names(full_names: Iterable[str]) -> dict[str, str]:
    """Map ``/``-separated channel names to shortened forms, as computed by
    ``shorten_to_unambiguous_suffixes`` over the trailing ``n`` path components.
    """
    return shorten_to_unambiguous_suffixes(
        full_names, lambda fqn, n: "/".join(fqn.split("/")[-n:])
    )
def make_fragment_scan_exp(
    fragment_class: type[ExpFragment],
    *args,
    max_rtio_underflow_retries: int = 3,
    max_transitory_error_retries: int = 10,
) -> type[FragmentScanExperiment]:
    """Build a :class:`FragmentScanExperiment` subclass scanning the given
    :class:`.ExpFragment`, ready to be picked up by the ARTIQ explorer/…

    This is the default way of creating scan experiments::

        class MyExpFragment(ExpFragment):
            def build_fragment(self):
                # ...

            def run_once(self):
                # ...

        MyExpFragmentScan = make_fragment_scan_exp(MyExpFragment)

    :param fragment_class: The fragment class to instantiate as the scan root.
    :param args: Extra positional arguments passed on when instantiating
        ``fragment_class``.
    :param max_rtio_underflow_retries: Forwarded to
        :meth:`FragmentScanExperiment.build`.
    :param max_transitory_error_retries: Forwarded to
        :meth:`FragmentScanExperiment.build`.
    :return: The newly created experiment class, carrying the name and docstring of
        ``fragment_class``.
    """

    class FragmentScanShim(FragmentScanExperiment):
        def build(self):
            def create_fragment():
                return fragment_class(self, [], *args)

            super().build(
                create_fragment,
                max_rtio_underflow_retries=max_rtio_underflow_retries,
                max_transitory_error_retries=max_transitory_error_retries,
            )

    # The fragment class docstring is what the experiment explorer UI displays.
    FragmentScanShim.__doc__ = fragment_class.__doc__

    # Adopt the fragment class name so result file names stay informative.
    FragmentScanShim.__name__ = fragment_class.__name__

    return FragmentScanShim
class _FragmentRunner(HasEnvironment):
    """Object wrapping fragment execution to be able to execute everything in one
    kernel invocation (no difference for non-kernel fragments).

    The underflow/transitory error counters are instance attributes initialised in
    :meth:`build` and are not reset between calls to :meth:`run`.
    """

    def build(
        self,
        fragment: ExpFragment,
        max_rtio_underflow_retries: int,
        max_transitory_error_retries: int,
    ):
        """
        :param fragment: The fragment to execute.
        :param max_rtio_underflow_retries: Number of RTIOUnderflows to catch and
            retry on before re-raising.
        :param max_transitory_error_retries: Number of transitory errors to catch and
            retry on before re-raising.
        """
        self.fragment = fragment
        self.max_rtio_underflow_retries = max_rtio_underflow_retries
        self.max_transitory_error_retries = max_transitory_error_retries
        self.num_underflows_caught = 0
        self.num_transitory_errors_caught = 0

    def run(self) -> bool:
        """Execute device_setup()/run_once(), retrying if necessary.

        :return: ``True`` if execution completed, ``False`` if it should be attempted
            again (RestartKernelTransitoryError).
        """
        # TODO: Unify with TopLevelRunner._run_continuous().
        if is_kernel(self.fragment.run_once):
            # The core device is only requested if the fragment actually runs on it.
            self.setattr_device("core")
            return self._run_on_kernel()
        else:
            return self._run()

    @kernel
    def _run_on_kernel(self):
        """Force the portable _run() to run on the kernel."""
        return self._run()

    @portable
    def _run(self):
        # Retry loop around device_setup()/run_once(). Returns True on success, and
        # False if a RestartKernelTransitoryError was caught and the caller should
        # invoke run() again. Once a retry limit is exceeded, the exception is
        # re-raised. device_cleanup() is invoked on every way out.
        try:
            while True:
                try:
                    self.fragment.device_setup()
                    self.fragment.run_once()
                    return True
                except RTIOUnderflow:
                    self.num_underflows_caught += 1
                    if self.num_underflows_caught > self.max_rtio_underflow_retries:
                        raise
                    print(
                        "Ignoring RTIOUnderflow (",
                        self.num_underflows_caught,
                        "/",
                        self.max_rtio_underflow_retries,
                        ")",
                    )
                except RestartKernelTransitoryError:
                    self.num_transitory_errors_caught += 1
                    if (
                        self.num_transitory_errors_caught
                        > self.max_transitory_error_retries
                    ):
                        raise
                    print("Caught transitory error, restarting kernel")
                    return False
                except TransitoryError:
                    self.num_transitory_errors_caught += 1
                    if (
                        self.num_transitory_errors_caught
                        > self.max_transitory_error_retries
                    ):
                        raise
                    print(
                        "Caught transitory error (",
                        self.num_transitory_errors_caught,
                        "/",
                        self.max_transitory_error_retries,
                        "), retrying",
                    )
        finally:
            self.fragment.device_cleanup()
        assert False, "Execution never reaches here, return is just to pacify compiler."
        return True
def run_fragment_once(
    fragment: ExpFragment,
    max_rtio_underflow_retries: int = 3,
    max_transitory_error_retries: int = 10,
    overrides: dict[str, list[tuple[str, ParamStore]]] | None = None,
) -> dict[ResultChannel, Any]:
    """Initialise the passed fragment and run it once, capturing and returning the
    values from any result channels.

    :param fragment: The fragment to initialise and run.
    :param max_rtio_underflow_retries: Number of times to catch RTIOUnderflow
        exceptions and retry execution. If exceeded, the exception is re-raised for
        the caller to handle. If ``0``, retrying is disabled entirely.
    :param max_transitory_error_retries: Number of times to catch transitory error
        exceptions and retry execution. If exceeded, the exception is re-raised for
        the caller to handle. If ``0``, retrying is disabled entirely.
    :param overrides: Any parameter overrides to apply when initialising the
        fragment; see :meth:`.Fragment.init_params`. ``None`` (the default) is
        equivalent to an empty dictionary.

    :return: A dictionary mapping :class:`ResultChannel` instances to their values
        (or ``None`` if not pushed to).
    """
    if overrides is None:
        # Fresh dictionary per call rather than a default shared between all calls.
        overrides = {}

    channel_dict = {}
    fragment._collect_result_channels(channel_dict)
    sinks = {channel: LastValueSink() for channel in channel_dict.values()}
    for channel, sink in sinks.items():
        channel.set_sink(sink)

    runner = _FragmentRunner(
        fragment, fragment, max_rtio_underflow_retries, max_transitory_error_retries
    )
    fragment.init_params(overrides=overrides)
    fragment.prepare()
    try:
        # runner.run() returns False if execution is to be attempted again, in which
        # case host_setup() is invoked anew beforehand.
        while True:
            fragment.host_setup()
            if runner.run():
                break
    finally:
        fragment.host_cleanup()

    return {channel: sink.get_last() for channel, sink in sinks.items()}
def create_and_run_fragment_once(
    env: HasEnvironment,
    fragment_class: type[ExpFragment],
    *args,
    max_rtio_underflow_retries: int = 3,
    max_transitory_error_retries: int = 10,
    overrides: dict[str, list[tuple[str, ParamStore]]] | None = None,
    **kwargs,
) -> dict[str, Any]:
    """Create an instance of the passed :class:`.ExpFragment` type and run it once,
    returning the values pushed to any result channels.

    Example::

        class MyExpFragment(ExpFragment):
            def build_fragment(self):
                # ...
                self.setattr_result("foo")

            def run_once(self):
                # ...

        class MyEnvExperiment(EnvExperiment):
            def run(self):
                results = create_and_run_fragment_once(self, MyExpFragment)
                print(results["foo"])

    :param env: The ``HasEnvironment`` to use.
    :param fragment_class: The :class:`.ExpFragment` class to instantiate.
    :param args: Any arguments to forward to ``build_fragment()``.
    :param max_rtio_underflow_retries: Forwarded to :meth:`run_fragment_once`.
    :param max_transitory_error_retries: Forwarded to :meth:`run_fragment_once`.
    :param overrides: Any parameter overrides, forwarded to
        :meth:`run_fragment_once`. ``None`` (the default) is equivalent to an empty
        dictionary.
    :param kwargs: Any keyword arguments to forward to ``build_fragment()``.

    :return: A dictionary mapping result channel names to their values (or ``None``
        if not pushed to).
    """
    if overrides is None:
        # Fresh dictionary per call rather than a default shared between all calls.
        overrides = {}

    results = run_fragment_once(
        fragment_class(env, [], *args, **kwargs),
        max_rtio_underflow_retries=max_rtio_underflow_retries,
        max_transitory_error_retries=max_transitory_error_retries,
        overrides=overrides,
    )
    # Result channels are keyed by their shortened path in the returned dictionary.
    shortened_names = _shorten_result_channel_names(
        channel.path for channel in results.keys()
    )
    return {shortened_names[channel.path]: value for channel, value in results.items()}