Source code for code_generation.friend_trees

from __future__ import annotations  # needed for type annotations in > python 3.7

import logging
import ROOT
import json
import os
from time import time
from code_generation.configuration import Configuration
from typing import List, Union, Dict, Set

from code_generation.exceptions import (
    ConfigurationError,
    InvalidOutputError,
    InsufficientShiftInformationError,
    InvalidInputError,
)
from code_generation.producer import Producer, ProducerGroup
from code_generation.rules import ProducerRule

log = logging.getLogger(__name__)


[docs] class FriendTreeConfiguration(Configuration): """ Configuration class for a FriendTree production with the CROWN framework. Based on the main Configuration class, but with a few modifications nessessary for a FriendTree configuration. The biggest differences are * the nominal version of quantities is optional and should only run if the user specifies it * no global scope is required * only one scope is allowed * The ordering is not optimized, but taken directly from the configuration file * information about the input file is required. This information can be provided by a json file, or by providing an input root file. * When using an input root file, only a single sample type and a single scope are allowed """ def __init__( self, era: str, sample: str, scope: Union[str, List[str]], shifts: Set[str], available_sample_types: Union[str, List[str]], available_eras: Union[str, List[str]], available_scopes: Union[str, List[str]], input_information: Union[str, List[str]], ): """Generate a configuration for a FriendTree production. Args: era (str): The era of the sample sample (str): The sample type scope (Union[str, List[str]]): The scope of the sample shifts (Set[str]): The shifts to be applied, can be "all", "none" or a list of shifts available_sample_types (Union[str, List[str]]): A list of available sample types available_eras (Union[str, List[str]]): A list of available eras available_scopes (Union[str, List[str]]): A list of available scopes input_information (Union[str, Dict[str, List[str]]]): Information about the input file. Can be a json file or a root file """ # super().__init__( era, sample, scope, shifts, available_sample_types, available_eras, available_scopes, ) self.run_nominal = False # in the main constructor, the global scope is added to the scopes list. # This is not needed for a friend tree configuration, so we remove it again here if self.global_scope in self.scopes: self.scopes.remove(self.global_scope) self.global_scope = None log.warn(f"Selected scopes: {self.selected_scopes}") # if more than one scope is specified, raise an error if len(self.selected_scopes) > 1: raise ConfigurationError( f"FriendTree configurations can only have one scope, but multiple {self.selected_scopes} were specified" ) if not isinstance(input_information, list): input_information_list = [input_information] else: input_information_list = input_information self.input_quantities_mapping = self._readout_input_information( input_information_list ) # all requested shifts are stored in a seperate varaiable, # they have to be added to all producers later self.requested_shifts = self._determine_requested_shifts(shifts) def _determine_requested_shifts(self, shiftset: Set[str]) -> Dict[str, List[str]]: """Determine the requested shifts from the user input Args: shifts (Union[str, List[str]]): User input for the shifts Returns: List[str]: List of requested shifts """ requested_shifts: Dict[str, List[str]] = {} # first convert shifts to a list shifts = list(shiftset) testshifts = [shift.lower() for shift in shifts] # check if the user has specified "all" or "none" if "all" in testshifts or "none" in testshifts: if len(testshifts) > 1: raise ConfigurationError( "When using 'all' or 'none' as a shift, no other shifts can be specified" ) if testshifts[0] == "all": for scope in self.selected_scopes: requested_shifts[scope] = list( self.input_quantities_mapping[scope].keys() ) # remove "" from the list requested_shifts[scope].remove("") requested_shifts[scope].append("nominal") self.run_nominal = True elif testshifts[0] == "none": for scope in self.selected_scopes: requested_shifts[scope] = ["nominal"] self.run_nominal = True else: # in this case, the user has specified a list of shifts # we have to check if the shifts are valid for scope in self.selected_scopes: requested_shifts[scope] = [] for shift in shifts: if shift == "nominal": self.run_nominal = True elif shift not in self.input_quantities_mapping[scope].keys(): raise InsufficientShiftInformationError( shift, list(self.input_quantities_mapping[scope].keys()) ) requested_shifts[scope].append(shift) return requested_shifts def _readout_input_information( self, input_information_list: Union[List[str], List[Dict[str, List[str]]]], ) -> Dict[str, Dict[str, List[str]]]: def update_input_information(existing_data, new_data): if existing_data == {}: return new_data else: # otherwise we have to merge the contents, while not overwriting existing data for scope in new_data.keys(): if scope not in existing_data.keys(): existing_data[scope] = {} for shift in new_data[scope].keys(): if shift not in existing_data[scope].keys(): existing_data[scope][shift] = [] for quantity in new_data[scope][shift]: if quantity not in existing_data[scope][shift]: existing_data[scope][shift].append(quantity) return existing_data # first check if the input is a root file or a json file data = {} for input_information in input_information_list: log.info(f"adding input information from {input_information}") if isinstance(input_information, str): if input_information.endswith(".root"): data = update_input_information( data, self._readout_input_root_file(input_information) ) elif input_information.endswith(".json"): data = update_input_information( data, self._readout_input_json_file(input_information) ) else: error_message = f"\n Input information file {input_information} is not a json or root file \n" error_message += ( " Did you forget to specify the input information file? \n" ) error_message += " The input information has to be a json file or a root file \n" error_message += " and added to the cmake command via the -DQUANTITIESMAP=... option" raise ConfigurationError(error_message) return data def _readout_input_root_file( self, input_file: str ) -> Dict[str, Dict[str, List[str]]]: """Read the shift_quantities_map from the input root file and return it as a dictionary Args: input_file (str): Path to the input root file Returns: Dict[str, List[str]]: Dictionary containing the shift_quantities_map """ data = {} if len(self.selected_scopes) > 1: raise ConfigurationError( "When using an input root file, only a single scope is possible" ) start = time() log.debug(f"Reading quantities information from {input_file}") # Load dict parsing lib lib_path = os.path.abspath("build/libMyDicts.so") # Physical file check if not os.path.exists(lib_path): log.error(f"Missing library: {lib_path}") # Evaluate ROOT-specific return codes result = ROOT.gSystem.Load(lib_path) if result < 0: err_type = ( "Version mismatch" if result == -2 else "Linker error/Missing dependency" ) log.error(f"Load failed ({result}): {err_type} for {lib_path}") f = ROOT.TFile.Open(input_file) # type: ignore name = "shift_quantities_map" m = f.Get(name) for shift, quantities in m: data[str(shift)] = [str(quantity) for quantity in quantities] f.Close() log.debug( f"Reading quantities information took {round(time() - start,2)} seconds" ) return {list(self.selected_scopes)[0]: data} def _readout_input_json_file( self, input_file: str ) -> Dict[str, Dict[str, List[str]]]: """Read the shift_quantities_map from the input json file and return it as a dictionary Args: input_file (str): Path to the input json file Returns: Dict[str, List[str]]: Dictionary containing the shift_quantities_map """ with open(input_file) as f: data = json.load(f) # json file structure is: {era: {sampletype: {scope: {shift: [quantities]}}} if self.era not in data: errorstring = ( f"Era {self.era} not found in input information file {input_file}.\n" ) errorstring += f"Available eras are: {data.keys()}" raise ConfigurationError(errorstring) if self.sample not in data[self.era].keys(): errorstring = f"Sampletype {self.sample} not found in input information file {input_file}.\n" errorstring += f"Available sampletypes are: {data[self.era].keys()}" raise ConfigurationError(errorstring) if not set(self.selected_scopes).issubset( set(data[self.era][self.sample].keys()) ): errorstring = f"Scopes {self.selected_scopes} not found in input information file {input_file}.\n" errorstring += f"Available scopes are: {data[self.era][self.sample].keys()}" raise ConfigurationError(errorstring) else: data = data[self.era][self.sample] return data
[docs] def optimize(self) -> None: """ Function used to optimize the FriendTreeConfiguration. In this case, no ordering optimization is performed. Optimizaion steps are: 1. Apply rules 2. Add all requested shifts to all producers. This addition is trivial, since the shifted quantities are already available in the input file 3. Remove empty scopes Args: None Returns: None """ self._apply_rules() self._add_requested_shifts() self._remove_empty_scopes() self._validate_inputs()
def _add_requested_shifts(self) -> None: # first shift the output quantities for scope in self.selected_scopes: for shift in self.requested_shifts[scope]: if shift != "nominal": shiftname = "__" + shift for producer in self.producers[scope]: # second step is to shift the inputs of the producer self._shift_producer_inputs(producer, shift, shiftname, scope) self.shifts[scope][shiftname] = {} def _shift_producer_inputs( self, producer: Union[Producer, ProducerGroup], shift: str, shiftname: str, scope: str, ) -> None: """Function used to determine which inputs of a producer have to be shifted. If none of the inputs of a producer is available in the shift_quantities_map, the producer is skipped. Args: producer (Union[Producer, ProducerGroup]): The producer to be checked and possibly shifted shift (str): the shift to be added shiftname (str): the name of the shift to be added scope (str): The scope to be checked """ log.debug("Shifting inputs of producer %s", producer) # if the producer is not of Type ProducerGroup we can directly shift the inputs if isinstance(producer, Producer): inputs = producer.get_inputs(scope) log.debug("Inputs of producer %s: %s", producer, inputs) # only shift if necessary if shift in self.input_quantities_mapping[scope].keys(): inputs_to_shift = [] for input_quantity in inputs: if ( input_quantity.name in self.input_quantities_mapping[scope][shift] ): inputs_to_shift.append(input_quantity) if len(inputs_to_shift) > 0: log.debug("Adding shift %s to producer %s", shift, producer) producer.shift(shiftname, scope) log.debug( f"Shifting inputs {inputs_to_shift} of producer {producer} by {shift}" ) producer.shift_inputs(shiftname, scope, inputs_to_shift) else: log.info( f"no inputs to shift for producer {producer} and shift {shift}, skipping" ) elif isinstance(producer, ProducerGroup): for producer in producer.producers[scope]: self._shift_producer_inputs(producer, shift, shiftname, scope) def _validate_outputs(self) -> None: """ Function used to validate the defined outputs. If an output is requested in the configuration, but is not available, since no producer will be able to produce it, an error is raised. Args: None Returns: None """ for scope in [scope for scope in self.scopes]: required_outputs = set(output for output in self.outputs[scope]) # merge the two sets of outputs provided_outputs = self.available_outputs[scope][self.sample] missing_outputs = required_outputs - provided_outputs if len(missing_outputs) > 0: raise InvalidOutputError(scope, missing_outputs) def _validate_inputs(self) -> None: """ The `_validate_inputs` function checks if all required inputs for each producer in the given scopes are available, and raises an error if any inputs are missing. """ for scope in [scope for scope in self.scopes]: # get all inputs of all producers required_inputs = set() available_inputs = set() for producer in self.producers[scope]: required_inputs = required_inputs | set( [x.name for x in producer.get_inputs(scope)] ) available_inputs = available_inputs | set( [x.name for x in producer.get_outputs(scope)] ) # get all available inputs for input_quantitiy in self.input_quantities_mapping[scope][""]: available_inputs.add(input_quantitiy) # now check if all inputs are available missing_inputs = required_inputs - available_inputs if len(missing_inputs) > 0: for producer in self.producers[scope]: if ( len( missing_inputs & set([x.name for x in producer.get_inputs(scope)]) ) > 0 ): log.error(f"Missing inputs for {producer}") log.error(f"| Producer inputs: {producer.get_inputs(scope)}") log.error( f"| Missing inputs: {missing_inputs & set([ x.name for x in producer.get_inputs(scope)])}" ) raise InvalidInputError(scope, missing_inputs)
[docs] def add_modification_rule( self, scopes: Union[str, List[str]], rule: ProducerRule ) -> None: """ Function used to add a rule to the configuration. Args: scopes: The scopes to which the rule should be added. This can be a list of scopes or a single scope. rule: The rule to be added. This must be a ProducerRule object. Returns: None """ if not isinstance(rule, ProducerRule): raise TypeError("Rule must be of type ProducerRule") if not isinstance(scopes, list): scopes = [scopes] rule.set_available_sampletypes(self.available_sample_types) rule.set_scopes(scopes) # TODO Check if this works without a global scope if self.global_scope is not None: rule.set_global_scope(self.global_scope) self.rules.add(rule)
[docs] def expanded_configuration(self) -> Configuration: """Function used to generate an expanded version of the configuration, where all shifts are applied. This expanded configuration is used by the code generator to generate the C++ code. Returns: Configuration: Expanded configuration """ expanded_configuration = {} for scope in self.scopes: expanded_configuration[scope] = {} if self.run_nominal: log.debug("Adding nominal in scope {}".format(scope)) if scope not in self.config_parameters.keys(): raise ConfigurationError( "Scope {} not found in configuration parameters".format(scope) ) expanded_configuration[scope]["nominal"] = self.config_parameters[scope] if len(self.shifts[scope]) > 0: for shift in self.shifts[scope]: log.debug("Adding shift {} in scope {}".format(shift, scope)) log.debug(" {}".format(self.shifts[scope][shift])) try: expanded_configuration[scope][shift] = ( self.config_parameters[scope] | self.shifts[scope][shift] ) except KeyError: expanded_configuration[scope][shift] = {} expanded_configuration[scope][shift] = ( self.config_parameters[scope] | self.shifts[scope][shift] ) # check if any shift (including the nominal) is run, if not, exit with an error if not any( [len(expanded_configuration[scope]) > 0 for scope in expanded_configuration] ): error_msg = "Nothing to run, is the configuration valid? \n Provided Configuration: \n {}".format( self ) raise ConfigurationError(error_msg) self.config_parameters = expanded_configuration return self