Skip to content

packages/engine/scram-node/src/expression.h

Provides the base class for all expressions and units for expression values.

Namespaces

Name
scram
scram::mef
scram::mef::detail

Classes

Name
class scram::mef::Expression — Abstract base class for all sorts of expressions to describe events.
class scram::mef::ExpressionFormula — CRTP base for expressions that use the same formula to evaluate and to sample.
class scram::mef::NaryExpression<T, 1> — Unary expression.
class scram::mef::NaryExpression<T, 2> — Binary expression.
class scram::mef::NaryExpression<T, -1> — Multivariate expression.

Source code

cpp
/*
 * Copyright (C) 2014-2018 Olzhas Rakhimov
 * Copyright (C) 2023 OpenPRA ORG Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


#pragma once

#include <algorithm>
#include <cassert>
#include <tuple>
#include <utility>
#include <vector>

#include <boost/icl/continuous_interval.hpp>
#include <boost/noncopyable.hpp>

namespace scram::mef {

/// Interval of double values; Expression::interval() reports its range as one.
using Interval = boost::icl::continuous_interval<double>;
/// Boost.ICL descriptor of an interval's bound types (open/closed ends).
using IntervalBounds = boost::icl::interval_bounds;

inline bool Contains(const Interval& interval, double value) {
  return boost::icl::contains(interval, Interval::closed(value, value));
}

inline bool IsProbability(const Interval& interval) {
  return boost::icl::within(interval, Interval::closed(0, 1));
}

/// Returns true if the interval's lower bound is not below zero.
inline bool IsNonNegative(const Interval& interval) {
  const double lower_bound = interval.lower();
  return lower_bound >= 0;
}

/// Returns true if the interval is non-negative and does not contain zero
/// (e.g., a lower bound of 0 is acceptable only when that end is open).
inline bool IsPositive(const Interval& interval) {
  if (!IsNonNegative(interval)) {
    return false;
  }
  return !Contains(interval, 0);
}

/// Abstract base class for all sorts of expressions to describe events.
///
/// An expression keeps raw pointers to its argument expressions. The
/// destructor is defaulted, so it never deletes those arguments. It exposes
/// a point value (value()), a range of possible values (interval()) and a
/// sampled value (Sample()). Instances are non-copyable.
class Expression : private boost::noncopyable {
 public:
  /// @param[in] args  Argument expressions of this expression (not deleted by
  ///                  Expression).
  explicit Expression(std::vector<Expression*> args = {});

  virtual ~Expression() = default;

  /// @returns The argument expressions, in the order they were given/added.
  [[nodiscard]] const std::vector<Expression*>& args() const { return args_; }

  /// Hook for derived classes to validate their arguments.
  /// The base implementation accepts everything.
  /// NOTE(review): overrides presumably throw on invalid arguments — confirm.
  virtual void Validate() const {}

  /// @returns The deterministic value of the expression.
  virtual double value() noexcept = 0;

  /// @returns The range of values the expression may take.
  ///          The default is the degenerate interval [value(), value()].
  virtual Interval interval() noexcept {
    double value = this->value();
    return Interval::closed(value, value);
  }

  /// Defined out of line.
  /// NOTE(review): presumably reports whether the expression (or any
  /// argument) is a random deviate — confirm against the definition.
  virtual bool IsDeviate() noexcept;

  /// Defined out of line.
  /// NOTE(review): presumably returns a cached sample, calling DoSample()
  /// only on the first call after Reset() — confirm against the definition.
  double Sample() noexcept;

  /// Defined out of line.
  /// NOTE(review): presumably clears the cached sample — confirm.
  void Reset() noexcept;

 protected:
  /// Appends an argument expression after construction.
  void AddArg(Expression* arg) { args_.push_back(arg); }

 private:
  /// Produces a fresh sampled value; implemented by derived classes.
  virtual double DoSample() noexcept = 0;

  std::vector<Expression*> args_;  ///< Argument expressions (not owned here).
  // Default member initializers keep the sampling state well-defined even if
  // a constructor forgets to set it; the out-of-line constructor's own
  // initializers (if any) still take precedence.
  double sampled_value_ = 0;  ///< Presumably the cached result of DoSample().
  bool sampled_ = false;  ///< Presumably whether sampled_value_ is current.
};

template <class T>
class ExpressionFormula : public Expression {
 public:
  using Expression::Expression;

  double value() noexcept final {
    return static_cast<T*>(this)->Compute(
        [](Expression* arg) { return arg->value(); });
  }

 private:
  double DoSample() noexcept final {
    return static_cast<T*>(this)->Compute(
        [](Expression* arg) { return arg->Sample(); });
  }
};

/// Expression that applies the functor T to N argument expressions.
///
/// The primary template is only declared. The specializations for N == 1
/// (unary), N == 2 (binary) and N == -1 (any number of arguments) provide the
/// implementations.
template <class T, int N>
class NaryExpression;

/// Unary expression: applies the default-constructible functor T to a single
/// argument expression.
template <typename T>
class NaryExpression<T, 1> : public ExpressionFormula<NaryExpression<T, 1>> {
 public:
  /// @param[in] expression  The sole argument; it must outlive this object,
  ///                        because a reference to it is kept.
  explicit NaryExpression(Expression* expression)
      : ExpressionFormula<NaryExpression<T, 1>>({expression}),
        expression_(*expression) {}

  void Validate() const override {}

  /// Maps the argument's interval bounds through T and orders the results.
  /// Only the two bounds are evaluated, so this is exact only for functors
  /// that are monotonic over the argument's range.
  Interval interval() noexcept override {
    const Interval arg_range = expression_.interval();
    const double at_upper = T()(arg_range.upper());
    const double at_lower = T()(arg_range.lower());
    if (at_lower < at_upper) {
      return Interval::closed(at_lower, at_upper);
    }
    return Interval::closed(at_upper, at_lower);
  }

  /// Applies T to the argument as evaluated by eval.
  template <typename F>
  double Compute(F&& eval) noexcept {
    const double operand = eval(&expression_);
    return T()(operand);
  }

 private:
  Expression& expression_;  ///< The sole argument expression.
};

/// Binary expression: applies the default-constructible functor T to two
/// argument expressions, T(first, second).
template <typename T>
class NaryExpression<T, 2> : public ExpressionFormula<NaryExpression<T, 2>> {
 public:
  /// @param[in] arg_one  The first (left) operand.
  /// @param[in] arg_two  The second (right) operand.
  explicit NaryExpression(Expression* arg_one, Expression* arg_two)
      : ExpressionFormula<NaryExpression<T, 2>>({arg_one, arg_two}) {}

  void Validate() const override {}

  /// Bounds T over the argument intervals by evaluating it at the four
  /// combinations of lower/upper bounds. This is exact only when T is
  /// monotonic in each argument over those ranges.
  Interval interval() noexcept override {
    Interval interval_one = Expression::args().front()->interval();
    Interval interval_two = Expression::args().back()->interval();
    double max_max = T()(interval_one.upper(), interval_two.upper());
    double max_min = T()(interval_one.upper(), interval_two.lower());
    double min_max = T()(interval_one.lower(), interval_two.upper());
    double min_min = T()(interval_one.lower(), interval_two.lower());
    auto interval_pair = std::minmax({max_max, max_min, min_max, min_min});
    return Interval::closed(interval_pair.first, interval_pair.second);
  }

  /// Applies T to the two arguments as evaluated by eval.
  template <typename F>
  double Compute(F&& eval) noexcept {
    // Evaluate the operands in a fixed left-to-right order. Function-argument
    // evaluation order is unspecified in C++. Passing both eval() calls
    // straight to T() therefore lets the compiler pick which operand is
    // evaluated first. That matters when eval has side effects, as on the
    // sampling path, which calls Expression::Sample() on each argument.
    const double lhs = eval(Expression::args().front());
    const double rhs = eval(Expression::args().back());
    return T()(lhs, rhs);
  }
};

namespace detail {

void EnsureMultivariateArgs(std::vector<Expression*> args);

}  // namespace detail

/// Multivariate expression: folds the default-constructible binary functor T
/// over all argument expressions from left to right,
/// T(...T(T(a0, a1), a2)..., an).
///
/// interval() and Compute() dereference the first argument unconditionally.
/// They rely on the constructor's detail::EnsureMultivariateArgs() check for
/// a non-empty argument list.
template <typename T>
class NaryExpression<T, -1> : public ExpressionFormula<NaryExpression<T, -1>> {
 public:
  /// @param[in] args  The argument expressions, in folding order.
  explicit NaryExpression(std::vector<Expression*> args)
      : ExpressionFormula<NaryExpression<T, -1>>(std::move(args)) {
    detail::EnsureMultivariateArgs(Expression::args());
  }

  void Validate() const override {}

  /// Folds T over the argument intervals. At each step, T is evaluated at the
  /// four combinations of the running bounds and the next argument's bounds,
  /// and the smallest/largest results become the new running bounds. This is
  /// exact only when T is monotonic in each argument over those ranges.
  Interval interval() noexcept override {
    auto it = Expression::args().begin();
    Interval first_arg_interval = (*it)->interval();
    double max_value = first_arg_interval.upper();
    double min_value = first_arg_interval.lower();
    for (++it; it != Expression::args().end(); ++it) {
      Interval next_arg_interval = (*it)->interval();
      double arg_max = next_arg_interval.upper();
      double arg_min = next_arg_interval.lower();
      double max_max = T()(max_value, arg_max);
      double max_min = T()(max_value, arg_min);
      double min_max = T()(min_value, arg_max);
      double min_min = T()(min_value, arg_min);
      std::tie(min_value, max_value) =
          std::minmax({max_max, max_min, min_max, min_min});
    }
    assert(min_value <= max_value);
    return Interval::closed(min_value, max_value);
  }

  /// Folds T left to right over the arguments as evaluated by eval. Each
  /// argument is evaluated exactly once, in argument order.
  template <typename F>
  double Compute(F&& eval) noexcept {
    auto it = Expression::args().begin();
    double result = eval(*it);
    for (++it; it != Expression::args().end(); ++it) {
      result = T()(result, eval(*it));
    }
    return result;
  }
};

void EnsureProbability(Expression* expression,
                       const char* type = "probability");

void EnsurePositive(Expression* expression, const char* description);

void EnsureNonNegative(Expression* expression, const char* description);

void EnsureWithin(Expression* expression, const Interval& interval,
                  const char* type);

}  // namespace scram::mef

Updated on 2025-11-11 at 16:51:08 +0000