import logging
from typing import List, Union
from code_generation.producer import (
CollectProducersOutput,
Producer,
ProducerGroup,
TProducerInput,
TProducerStore,
)
from code_generation.quantity import QuantitiesStore
from code_generation.exceptions import SampleRuleConfigurationError
# Module-level logger used by the producer rule classes in this module.
log = logging.getLogger(__name__)
class ProducerRule:
    """Base class for all rules that modify the producer configuration.

    A rule targets a set of producers in a set of scopes and is only applied
    for selected sample types. Inheriting classes override
    :meth:`update_producers` and :meth:`update_outputs`; :meth:`apply` calls
    both when the rule is applicable to a given sample.
    """

    def __init__(
        self,
        producers: TProducerInput,
        samples: Union[str, List[str], None] = None,
        exclude_samples: Union[str, List[str], None] = None,
        scopes: Union[str, List[str]] = "global",
        invert: bool = False,
        update_output: bool = True,
    ):
        """ProducerRule is a base class for all rules that modify producers.

        Args:
            producers: A single producer or producer group, or a list of
                producers or producer groups, to be modified.
            samples: A sample or list of samples for which the rule should be
                applied. Only one of samples and exclude_samples can be defined.
                ``None`` (the default) means "not given".
            exclude_samples: A sample or list of samples for which the rule
                should not be applied. Only one of samples and exclude_samples
                can be defined. ``None`` (the default) means "not given".
            scopes: The scope or scopes in which the rule should be applied.
                Defaults to "global".
            invert: If set, the inverse of the rule is applied. Defaults to False.
            update_output: If set, the output quantities are updated.
                Defaults to True.
        """
        # A bare producer / producer group is wrapped so callers can always iterate.
        if isinstance(producers, (ProducerGroup, Producer)):
            producers = [producers]
        self.producers = producers
        self.scopes = self._to_list(scopes)
        self.invert = invert
        self.update_output = update_output
        self.global_scope = "global"
        # Normalized to fresh lists: no shared mutable defaults and no aliasing
        # of lists owned by the caller.
        self.exclude_samples = self._to_list(exclude_samples)
        self.samples = self._to_list(samples)

    @staticmethod
    def _to_list(value: Union[str, List[str], None]) -> List[str]:
        """Normalize ``None``, a single string or an iterable of strings to a new list."""
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    def set_available_sampletypes(
        self, available_samples: Union[str, List[str]]
    ) -> None:
        """Register the available sample types and resolve ``self.samples``.

        Args:
            available_samples: A sample type or list of sample types that exist.

        Raises:
            ValueError: If neither or both of samples and exclude_samples are defined.
            SampleRuleConfigurationError: If a sample in samples or
                exclude_samples is not one of the available samples.

        NOTE(review): for a rule configured via exclude_samples this overwrites
        ``self.samples``, so calling this method a second time on the same rule
        raises the "Both samples and exclude_samples" ValueError.
        """
        self.available_samples = self._to_list(available_samples)
        # make sure that exactly one of samples / exclude_samples is defined
        if not self.exclude_samples and not self.samples:
            raise ValueError(
                f"ProducerRule: Either samples or exclude_samples have to be defined!: (Rule: {self}, Samples: {self.samples}, Excluded Samples: {self.exclude_samples})"
            )
        if self.exclude_samples and self.samples:
            raise ValueError(
                f"ProducerRule: Both samples and exclude_samples are defined, pick one!: (Rule: {self}, Samples: {self.samples}, Excluded Samples: {self.exclude_samples})"
            )
        # make sure that all referenced sample types are valid
        self.validate_sampletypes(self.samples)
        self.validate_sampletypes(self.exclude_samples)
        # With exclude_samples, the applicable samples are all available samples
        # minus the excluded ones. Sorting keeps the order (and thus str(rule)
        # and log output) deterministic across runs.
        if self.exclude_samples:
            self.samples = sorted(
                set(self.available_samples) - set(self.exclude_samples)
            )

    def set_scopes(self, scopes: Union[str, List[str]]) -> None:
        """Set the scopes the rule applies to; a single string is wrapped in a list."""
        self.scopes = self._to_list(scopes)

    def affected_scopes(self) -> List[str]:
        """Return the scopes this rule applies to."""
        return self.scopes

    def affected_producers(self) -> List[Union[Producer, ProducerGroup]]:
        """Return the producers / producer groups this rule refers to."""
        return self.producers

    def set_global_scope(self, global_scope: str) -> None:
        """Set the name of the scope treated as global (initially "global")."""
        self.global_scope = global_scope

    def validate_sampletypes(self, sampletypes: List[str]) -> None:
        """Check that every given sample type is one of the available sample types.

        Args:
            sampletypes (List[str]): Sample types to check against
                ``self.available_samples``.

        Raises:
            SampleRuleConfigurationError: For the first sample type that is not available.
        """
        for sample in sampletypes:
            if sample not in self.available_samples:
                raise SampleRuleConfigurationError(sample, self, self.available_samples)

    def is_applicable(self, sample: str) -> bool:
        """Return whether the rule applies to ``sample``, honouring the ``invert`` flag."""
        applicable = sample in self.samples
        return not applicable if self.invert else applicable

    def update_producers(
        self,
        producers_to_be_updated: TProducerStore,
        unpacked_producers: TProducerStore,
    ) -> None:
        """Placeholder for the producer modification; overridden by inheriting classes."""
        log.error("Operation not implemented for ProducerRule base class!")

    def update_outputs(self, outputs_to_be_updated: QuantitiesStore) -> None:
        """Placeholder for the output modification; overridden by inheriting classes."""
        log.error("Operation not implemented for ProducerRule base class!")

    def apply(
        self,
        sample: str,
        producers_to_be_updated: TProducerStore,
        unpacked_producers: TProducerStore,
        outputs_to_be_updated: QuantitiesStore,
    ) -> None:
        """Apply the rule for ``sample`` if :meth:`is_applicable` says so.

        Calls :meth:`update_producers` followed by :meth:`update_outputs`.
        """
        if not self.is_applicable(sample):
            return
        # Routine operation: INFO rather than WARNING, logged once.
        log.info(f"Applying rule {self} for sample {sample}")
        self.update_producers(producers_to_be_updated, unpacked_producers)
        self.update_outputs(outputs_to_be_updated)
class RemoveProducer(ProducerRule):
    """Rule that removes producers (and optionally their outputs) from scopes.

    A producer that is not listed directly in a scope but is known to belong to
    a producer group there causes that group to be replaced by its individual
    members, after which the producer itself is removed.
    """

    def __str__(self) -> str:
        return "ProducerRule - remove {} for {} in scopes {}".format(
            self.producers, self.samples, self.scopes
        )

    def __repr__(self) -> str:
        return self.__str__()

    def _unpack_group_of(
        self,
        producer,
        scope: str,
        scope_producers,
        unpacked_producers: TProducerStore,
    ) -> None:
        """Replace the producer group containing ``producer`` by the group's members.

        Raises:
            ConnectionError: If ``producer`` is not part of any producer group in ``scope``.
        """
        group_of = unpacked_producers[scope]
        if producer not in group_of.keys():
            raise ConnectionError(
                "Producer {} not found in scope {}, cannot apply \n {}".format(
                    producer, scope, self
                )
            )
        log.debug("Found {} within a producer group".format(producer))
        group = group_of[producer]
        log.debug("Removing {} from producer group {}".format(producer, group))
        log.debug("Replacing {} with its unpacked producers".format(group))
        scope_producers.remove(group)
        # every member of the group (including `producer`) becomes a top-level entry
        for member, member_group in group_of.items():
            if member_group == group:
                scope_producers.append(member)

    def update_producers(
        self,
        producers_to_be_updated: TProducerStore,
        unpacked_producers: TProducerStore,
    ) -> None:
        """Remove each of the rule's producers from each of the rule's scopes.

        Raises:
            ConnectionError: If a producer is neither listed in a scope nor
                part of a producer group there.
        """
        log.debug("Producers to be updated: {}".format(producers_to_be_updated))
        log.debug("scopes: {}".format(self.scopes))
        log.debug("Producers to be removed: {}".format(self.producers))
        for scope in self.scopes:
            for producer in self.producers:
                scope_producers = producers_to_be_updated[scope]
                if producer in scope_producers:
                    log.debug(
                        "RemoveProducer: Removing {} from producer in scope {}".format(
                            producer, scope
                        )
                    )
                else:
                    # not listed directly -> may be hidden in a producer group
                    self._unpack_group_of(
                        producer, scope, scope_producers, unpacked_producers
                    )
                    log.debug(scope_producers)
                scope_producers.remove(producer)

    def update_outputs(self, outputs_to_be_updated: QuantitiesStore) -> None:
        """Remove the outputs of the rule's producers, if ``update_output`` is set.

        For a rule on the global scope only, the producers' global-scope outputs
        are removed from every scope in ``outputs_to_be_updated``; otherwise each
        of the rule's scopes loses the outputs collected for that scope.
        """
        if not self.update_output:
            return
        if self.scopes == [self.global_scope]:
            log.debug(
                "Updating outputs for producer in global scope --> adding output to all scopes"
            )
            target_scopes = outputs_to_be_updated.keys()
            collected = {
                scope: CollectProducersOutput(self.producers, self.global_scope)
                for scope in target_scopes
            }
        else:
            target_scopes = self.scopes
            collected = {
                scope: CollectProducersOutput(self.producers, scope)
                for scope in target_scopes
            }
        for scope in target_scopes:
            for quantity in collected[scope]:
                if quantity not in outputs_to_be_updated[scope]:
                    continue
                log.debug(
                    "RemoveProducer: Removing {} from outputs in scope {}".format(
                        quantity, scope
                    )
                )
                outputs_to_be_updated[scope].remove(quantity)
class AppendProducer(ProducerRule):
    """Rule that appends producers (and optionally their outputs) to scopes."""

    def __str__(self) -> str:
        return "ProducerRule - add {} for {} in scopes {}".format(
            self.producers, self.samples, self.scopes
        )

    def __repr__(self) -> str:
        return self.__str__()

    def update_producers(
        self,
        producers_to_be_updated: TProducerStore,
        unpacked_producers: TProducerStore,
    ) -> None:
        """Append each of the rule's producers to each of the rule's scopes.

        ``unpacked_producers`` is accepted for interface compatibility and not used.
        """
        for target_scope in self.scopes:
            for new_producer in self.producers:
                log.debug(
                    "AppendProducer: Adding {} to producers in scope {}".format(
                        new_producer, target_scope
                    )
                )
                producers_to_be_updated[target_scope].append(new_producer)

    def update_outputs(self, outputs_to_be_updated: QuantitiesStore) -> None:
        """Add the outputs of the rule's producers, if ``update_output`` is set.

        For a rule on the global scope only, the producers' global-scope outputs
        are added to every scope in ``outputs_to_be_updated``; otherwise each of
        the rule's scopes gains the outputs collected for that scope.
        """
        if not self.update_output:
            return
        if self.scopes == [self.global_scope]:
            log.debug(
                "Updating outputs for producer in global scope --> adding output to all scopes"
            )
            target_scopes = outputs_to_be_updated.keys()
            new_outputs = {
                target_scope: CollectProducersOutput(self.producers, self.global_scope)
                for target_scope in target_scopes
            }
        else:
            target_scopes = self.scopes
            new_outputs = {
                target_scope: CollectProducersOutput(self.producers, target_scope)
                for target_scope in target_scopes
            }
        for target_scope in target_scopes:
            for quantity in new_outputs[target_scope]:
                log.debug(
                    "AppendProducer: Adding {} to outputs in scope {}".format(
                        quantity, target_scope
                    )
                )
                outputs_to_be_updated[target_scope].add(quantity)
class ReplaceProducer(ProducerRule):
    """Rule that replaces one producer by another in the configured scopes.

    ``self.producers`` is expected to hold exactly two entries: the producer to
    be replaced and its replacement.
    """

    def _describe(self) -> str:
        # Fall back to the raw list for a misconfigured rule so that formatting
        # the rule (e.g. inside an exception message) never raises itself.
        if len(self.producers) == 2:
            return "ProducerRule - replace {} with {} for {} in scopes {}".format(
                self.producers[0], self.producers[1], self.samples, self.scopes
            )
        return "ProducerRule - replace {} for {} in scopes {}".format(
            self.producers, self.samples, self.scopes
        )

    def __str__(self) -> str:
        return self._describe()

    def __repr__(self) -> str:
        return self._describe()

    def _replacement_pair(self):
        """Return ``(producer_to_replace, replacement)`` from ``self.producers``.

        Raises:
            ValueError: If fewer than two producers are configured.
        """
        if len(self.producers) < 2:
            raise ValueError(
                "ReplaceProducer needs two producers (old, new), got {}".format(
                    self.producers
                )
            )
        if len(self.producers) > 2:
            log.warning(
                "ReplaceProducer: only {} is replaced with {}; additional producers {} are ignored".format(
                    self.producers[0], self.producers[1], self.producers[2:]
                )
            )
        return self.producers[0], self.producers[1]

    def _unpack_group_of(
        self,
        producer,
        new_producer,
        scope: str,
        scope_producers,
        unpacked_producers: TProducerStore,
    ) -> None:
        """Replace the producer group containing ``producer`` by the group's members.

        Raises:
            ConnectionError: If ``producer`` is not part of any producer group in
                ``scope`` (exception type kept as originally raised here).
        """
        group_of = unpacked_producers[scope]
        if producer not in group_of.keys():
            raise ConnectionError(
                "Producer {} not found in scope {}, cannot apply \n {}".format(
                    producer, scope, self
                )
            )
        log.debug("Found {} within a producer group".format(producer))
        group = group_of[producer]
        log.debug(
            "Replacing {} from producer group {} with {}".format(
                producer, group, new_producer
            )
        )
        log.debug("Replacing {} with its unpacked producers".format(group))
        scope_producers.remove(group)
        # every member of the group (including `producer`) becomes a top-level entry
        for member, member_group in group_of.items():
            if member_group == group:
                scope_producers.append(member)

    def update_producers(
        self,
        producers_to_be_updated: TProducerStore,
        unpacked_producers: TProducerStore,
    ) -> None:
        """Replace the first configured producer by the second in each of the rule's scopes.

        The replacement is appended at the end of the scope's producer list.

        Raises:
            ValueError: If fewer than two producers are configured.
            ConnectionError: If the producer to be replaced is neither listed in
                a scope nor part of a producer group there.
        """
        log.debug("Producers to be updated: {}".format(producers_to_be_updated))
        log.debug("scopes: {}".format(self.scopes))
        log.debug("Producers to be replaced: {}".format(self.producers))
        producer, new_producer = self._replacement_pair()
        for scope in self.scopes:
            scope_producers = producers_to_be_updated[scope]
            if producer in scope_producers:
                log.debug(
                    "ReplaceProducer: Replace {} from producer in scope {} with {}".format(
                        producer, scope, new_producer
                    )
                )
            else:
                # not listed directly -> may be hidden in a producer group
                self._unpack_group_of(
                    producer, new_producer, scope, scope_producers, unpacked_producers
                )
                log.debug(scope_producers)
            scope_producers.remove(producer)
            scope_producers.append(new_producer)

    def update_outputs(self, outputs_to_be_updated: QuantitiesStore) -> None:
        """Swap the outputs of the replaced producer for those of its replacement.

        Only active if ``update_output`` is set. For a rule on the global scope
        only, every scope in ``outputs_to_be_updated`` is updated with the
        global-scope outputs; otherwise only the rule's scopes are updated.

        Raises:
            ValueError: If fewer than two producers are configured.
        """
        if not self.update_output:
            return
        old_producer, new_producer = self._replacement_pair()
        if self.scopes == [self.global_scope]:
            log.debug(
                "Updating outputs for producer in global scope --> adding output to all scopes"
            )
            scopes = outputs_to_be_updated.keys()
            # Loop-invariant: the global-scope outputs are the same for every scope.
            removed = CollectProducersOutput([old_producer], self.global_scope)
            added = CollectProducersOutput([new_producer], self.global_scope)
            removed_outputs = {scope: removed for scope in scopes}
            added_outputs = {scope: added for scope in scopes}
        else:
            log.debug(f"Updating outputs for producer in scopes {self.scopes}")
            scopes = self.scopes
            removed_outputs = {
                scope: CollectProducersOutput([old_producer], scope) for scope in scopes
            }
            added_outputs = {
                scope: CollectProducersOutput([new_producer], scope) for scope in scopes
            }
        if added_outputs == removed_outputs:
            log.debug(
                f"Outputs {added_outputs} are identical, no need to update outputs"
            )
        else:
            log.warning(
                f"Outputs {added_outputs} that should be added are not identical to the removed outputs {removed_outputs}"
            )
        for scope in scopes:
            # A replacement output is only (re-)added if it was already present in
            # the scope's outputs before the removal below; replacement outputs
            # that were not present are not added.
            outputs_after_replace = [
                added_output
                for added_output in added_outputs[scope]
                if added_output in outputs_to_be_updated[scope]
            ]
            for removed_output in removed_outputs[scope]:
                if removed_output in outputs_to_be_updated[scope]:
                    log.debug(
                        "ReplaceProducer: Removing {} from outputs in scope {}".format(
                            removed_output, scope
                        )
                    )
                    outputs_to_be_updated[scope].remove(removed_output)
            for added_output in added_outputs[scope]:
                if added_output in outputs_after_replace:
                    log.debug(
                        "ReplaceProducer: Adding {} to outputs in scope {}".format(
                            added_output, scope
                        )
                    )
                    outputs_to_be_updated[scope].add(added_output)