Skip to content

packages/engine/scram-node/src/bdd.h

Fault tree analysis with Binary Decision Diagram (BDD) algorithms.

Namespaces

Name
scram
scram::core

Classes

Name
struct scram::core::IntrusivePtrCast — Provides pointer and reference cast wrappers for intrusive Vertex pointers.
class scram::core::WeakIntrusivePtr — A weak pointer to store in BDD unique tables.
class scram::core::Vertex — Representation of a vertex in BDD graphs.
class scram::core::Terminal — Representation of terminal vertices in BDD graphs.
class scram::core::NonTerminal — Representation of non-terminal vertices in BDD graphs.
class scram::core::Ite — Representation of non-terminal if-then-else vertices in BDD graphs.
class scram::core::UniqueTable — A hash table for keeping BDD reduced.
class scram::core::CacheTable — A hash table without collision resolution.
class scram::core::Bdd — Analysis of PDAGs with Binary Decision Diagrams.
struct scram::core::Bdd::Function — Holder of computation resultant functions and gate representations.
class scram::core::Bdd::Consensus — Provides access to consensus calculation private facilities.

Source code

cpp
/*
 * Copyright (C) 2014-2018 Olzhas Rakhimov
 * Copyright (C) 2023 OpenPRA ORG Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


#pragma once

#include <cassert>
#include <cmath>
#include <cstddef>

#include <algorithm>
#include <forward_list>
#include <limits>
#include <memory>
#include <new>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/functional/hash.hpp>
#include <boost/noncopyable.hpp>
#include <boost/smart_ptr/intrusive_ptr.hpp>

#include "pdag.h"
#include "settings.h"

namespace scram::core {

template <class T>
using IntrusivePtr = boost::intrusive_ptr<T>;

template <class T>
class Vertex;  // Manager of its own entry in the unique table.

template <class T, class W = T>
struct IntrusivePtrCast {
  static IntrusivePtr<W> Ptr(const IntrusivePtr<Vertex<T>>& vertex) {
    return boost::static_pointer_cast<W>(vertex);
  }
  static W& Ref(const IntrusivePtr<Vertex<T>>& vertex) {
    return static_cast<W&>(*vertex);
  }
};

/// A weak (non-owning) pointer to a vertex, stored as a BDD unique-table entry.
///
/// It does not touch the vertex's reference count. Instead, the vertex keeps a
/// back-pointer (`table_ptr_`) to this object and clears `vertex_` from its own
/// destructor, which is how expired() detects a destroyed vertex. Attaching
/// asserts that no other WeakIntrusivePtr already refers to the vertex.
///
/// @tparam T  The vertex type; its `table_ptr_` member must be accessible here.
template <class T>
class WeakIntrusivePtr final : private boost::noncopyable {
  friend class Vertex<T>;  // Communicates the destruction of the vertex.

 public:
  /// Constructs an empty (expired) entry.
  WeakIntrusivePtr() noexcept : vertex_(nullptr) {}

  /// Weakly refers to the vertex owned by @p ptr (which must be non-null).
  explicit WeakIntrusivePtr(const IntrusivePtr<T>& ptr) noexcept
      : vertex_(nullptr) {
    Attach(ptr);
  }

  /// Drops the current vertex (if any) and weakly refers to @p ptr instead.
  ///
  /// Equivalent to destroying and re-constructing this object, but without
  /// an explicit destructor call followed by placement new on `*this`.
  WeakIntrusivePtr& operator=(const IntrusivePtr<T>& ptr) noexcept {
    Detach();
    Attach(ptr);
    return *this;
  }

  /// Unlinks the vertex's back-pointer so it does not dangle.
  ~WeakIntrusivePtr() noexcept { Detach(); }

  /// @returns true if the vertex has been destroyed or was never set.
  bool expired() const noexcept { return !vertex_; }

  /// @returns An owning pointer to the vertex (empty if expired).
  IntrusivePtr<T> lock() const { return IntrusivePtr<T>(vertex_); }

  /// @returns The raw vertex pointer (nullptr if expired).
  T* get() const noexcept { return vertex_; }

 private:
  /// Points to @p ptr's vertex and registers this object as its table entry.
  void Attach(const IntrusivePtr<T>& ptr) noexcept {
    vertex_ = ptr.get();
    assert(vertex_->table_ptr_ == nullptr && "Non-unique table pointers.");
    vertex_->table_ptr_ = this;
  }

  /// Clears the vertex's back-pointer and forgets the vertex.
  void Detach() noexcept {
    if (vertex_) {
      vertex_->table_ptr_ = nullptr;
      vertex_ = nullptr;
    }
  }

  T* vertex_;  ///< The weakly referenced vertex; nulled on its destruction.
};

template <typename T>
class Terminal;  // Declared early so that Vertex<T> can delete terminals.

/// Base representation of a vertex in BDD graphs.
///
/// Vertices carry an integer id and an intrusive reference count. Ids 0 and 1
/// denote terminal vertices. A vertex also remembers the unique-table entry
/// (if any) that weakly refers to it, and invalidates that entry when it is
/// destroyed.
///
/// @tparam T  The concrete non-terminal vertex type.
template <typename T>
class Vertex : private boost::noncopyable {
  friend class WeakIntrusivePtr<T>;  // Both sides maintain the table link.

  /// Intrusive reference counting hook: one more owner.
  friend void intrusive_ptr_add_ref(Vertex<T>* ptr) noexcept {
    ++ptr->use_count_;
  }

  /// Intrusive reference counting hook: one owner less.
  ///
  /// When the count drops to zero, the vertex is deleted through its concrete
  /// type, because the base destructor is protected and non-virtual.
  friend void intrusive_ptr_release(Vertex<T>* ptr) noexcept {
    assert(ptr->use_count_ > 0 && "Missing reference counts.");
    if (--ptr->use_count_ != 0)
      return;
    if (ptr->terminal()) {
      delete static_cast<Terminal<T>*>(ptr);
      return;
    }
    delete static_cast<T*>(ptr);  // Non-terminal vertices are the usual case.
  }

 public:
  /// @param id  The identifier of the vertex (0 or 1 for terminals).
  explicit Vertex(int id) : id_(id) {}

  /// @returns The identifier of the vertex.
  int id() const { return id_; }

  /// @returns true for terminal vertices (ids 0 and 1).
  bool terminal() const { return id_ <= 1; }

  /// @returns The number of intrusive owners of this vertex.
  int use_count() const { return use_count_; }

  /// @returns true if exactly one intrusive pointer owns this vertex.
  bool unique() const {
    assert(use_count_ && "No registered shared pointers.");
    return use_count_ == 1;
  }

 protected:
  /// Invalidates the unique-table entry pointing to this vertex, if any.
  ~Vertex() noexcept {
    if (table_ptr_ != nullptr)
      table_ptr_->vertex_ = nullptr;
  }

 private:
  int id_;  ///< Vertex identifier.
  int use_count_ = 0;  ///< Intrusive reference count.
  WeakIntrusivePtr<T>* table_ptr_ = nullptr;  ///< Owning unique-table entry.
};

/// Terminal (constant) vertex of a BDD graph.
///
/// The Boolean value doubles as the vertex id: false maps to 0, true to 1.
template <typename T>
class Terminal : public Vertex<T>, public IntrusivePtrCast<T, Terminal<T>> {
 public:
  /// @param value  The constant Boolean value of this terminal.
  explicit Terminal(bool value) : Vertex<T>(value ? 1 : 0) {}

  /// @returns The constant Boolean value represented by this terminal.
  bool value() const { return this->id() != 0; }
};

/// Non-terminal vertex of a BDD graph with high and low branches.
///
/// Besides the two branches, it stores the variable index, the variable
/// order, and a few flags used by the analysis code.
///
/// @tparam T  The concrete non-terminal vertex type (the most derived class).
template <typename T>
class NonTerminal : public Vertex<T>, public IntrusivePtrCast<T> {
  using VertexPtr = IntrusivePtr<Vertex<T>>;

  /// @returns The id of the high branch vertex (unique-table key part).
  friend int get_high_id(const NonTerminal<T>& vertex) noexcept {
    const VertexPtr& branch = vertex.high_;
    return branch->id();
  }

  /// @returns The id of the low branch vertex (unique-table key part).
  friend int get_low_id(const NonTerminal<T>& vertex) noexcept {
    const VertexPtr& branch = vertex.low_;
    return branch->id();
  }

 public:
  /// @param index  The index of the decision variable.
  /// @param order  The order of the variable (asserted positive on access).
  /// @param id  The unique identifier of this vertex.
  /// @param high  The high (then) branch.
  /// @param low  The low (else) branch.
  NonTerminal(int index, int order, int id, const VertexPtr& high,
              const VertexPtr& low)
      : Vertex<T>(id), high_(high), low_(low), order_(order), index_(index) {}

  /// @returns The index of the decision variable.
  int index() const { return index_; }

  /// @returns The order of the decision variable.
  int order() const {
    assert(order_ > 0);
    return order_;
  }

  /// @returns The module flag.
  bool module() const { return module_; }

  /// Sets the module flag.
  void module(bool flag) { module_ = flag; }

  /// @returns The coherence flag.
  bool coherent() const { return coherent_; }

  /// Sets the coherence flag; a set flag must not be cleared again.
  void coherent(bool flag) {
    assert(!(coherent_ && !flag) && "Inverting existing coherence.");
    coherent_ = flag;
  }

  /// @returns The high (then) branch.
  const VertexPtr& high() const { return high_; }

  /// @returns The low (else) branch.
  const VertexPtr& low() const { return low_; }

  /// @returns The traversal mark.
  bool mark() const { return mark_; }

  /// Sets the traversal mark.
  void mark(bool flag) { mark_ = flag; }

 protected:
  ~NonTerminal() = default;

 private:
  VertexPtr high_;  ///< High (then) branch.
  VertexPtr low_;  ///< Low (else) branch.
  int order_;  ///< Order of the variable.
  int index_;  ///< Index of the variable.
  bool module_ = false;  ///< Module flag.
  bool coherent_ = false;  ///< Coherence flag.
  bool mark_ = false;  ///< Traversal mark.
};

/// If-then-else vertex: the non-terminal vertex type used by Bdd.
///
/// It adds a complement flag for the low edge and two floating-point slots
/// (p and factor). Their meaning is assigned by the analysis code, which is
/// not visible here.
class Ite : public NonTerminal<Ite> {
  /// Low-branch id with the complement encoded as the sign.
  ///
  /// The id is negated when the low edge is complemented. This overload is
  /// more specific than NonTerminal's, so it is selected for Ite arguments,
  /// e.g. in unique-table lookups.
  friend int get_low_id(const Ite& ite) noexcept {
    const int low_id = ite.low()->id();
    return ite.complement_edge_ ? -low_id : low_id;
  }

 public:
  using NonTerminal::NonTerminal;

  /// @returns true if the low edge is complemented.
  bool complement_edge() const { return complement_edge_; }

  /// Sets the low-edge complement flag.
  void complement_edge(bool flag) { complement_edge_ = flag; }

  /// @returns The stored p value (presumably a probability — confirm).
  double p() const { return p_; }

  /// Stores the p value.
  void p(double value) { p_ = value; }

  /// @returns The stored factor value.
  double factor() const { return factor_; }

  /// Stores the factor value.
  void factor(double value) { factor_ = value; }

 private:
  bool complement_edge_{false};  ///< Complement flag of the low edge.
  double p_{0.0};  ///< Analysis slot "p".
  double factor_{0.0};  ///< Analysis slot "factor".
};

using ItePtr = IntrusivePtr<Ite>;  

int GetPrimeNumber(int n);

/// A hash table for keeping BDD reduced.
///
/// It maps (index, high id, low id) triplets to the unique vertex with that
/// structure. Collisions are resolved by chaining. Entries are weak pointers,
/// so the table does not keep vertices alive. Entries of destroyed vertices
/// are purged lazily while probing a chain and are dropped on rehashing.
///
/// @tparam T  The non-terminal vertex type; `index()`, `get_high_id(const T&)`
///            and `get_low_id(const T&)` must be available for it.
template <class T>
class UniqueTable {
  using Bucket = std::forward_list<WeakIntrusivePtr<T>>;
  using Table = std::vector<Bucket>;

 public:
  /// @param init_capacity  Requested initial number of buckets
  ///                       (rounded through GetPrimeNumber).
  explicit UniqueTable(int init_capacity = 1000)
      : capacity_(core::GetPrimeNumber(init_capacity)),
        size_(0),
        max_load_factor_(0.75),
        table_(capacity_) {}

  /// @returns The number of entries, including not-yet-purged expired ones.
  int size() const { return size_; }

  /// Removes all entries but keeps the bucket array.
  void clear() {
    for (Bucket& chain : table_)
      chain.clear();
    size_ = 0;
  }

  /// Releases all memory held by the bucket array.
  ///
  /// The entry count is reset so that size() stays consistent. The capacity
  /// is kept, and the buckets are re-allocated lazily by the next FindOrAdd.
  void Release() {
    table_ = Table();
    size_ = 0;
  }

  /// Finds the entry for a vertex with the given structure, or adds one.
  ///
  /// Expired entries met in the probed chain are erased along the way.
  ///
  /// @param index  The variable index of the vertex.
  /// @param high_id  The id of the high branch.
  /// @param low_id  The (signed) id of the low branch.
  ///
  /// @returns The live matching entry, or a fresh empty (expired) entry
  ///          appended to the chain for the caller to fill.
  WeakIntrusivePtr<T>& FindOrAdd(int index, int high_id, int low_id) noexcept {
    if (table_.empty())  // The buckets were released; restore them.
      table_.resize(capacity_);
    if (size_ >= (max_load_factor_ * capacity_))
      Rehash(GetNextCapacity(capacity_));

    int bucket_number = Hash(index, high_id, low_id) % capacity_;
    Bucket& chain = table_[bucket_number];
    auto it_prev = chain.before_begin();  // Parent.
    for (auto it_cur = chain.begin(), it_end = chain.end(); it_cur != it_end;) {
      if (it_cur->expired()) {
        it_cur = chain.erase_after(it_prev);
        --size_;
      } else {
        T* vertex = it_cur->get();
        if (index == vertex->index() && high_id == get_high_id(*vertex) &&
            low_id == get_low_id(*vertex)) {
          return *it_cur;
        }
        it_prev = it_cur;
        ++it_cur;
      }
    }
    ++size_;
    return *chain.emplace_after(it_prev);
  }

 private:
  /// Moves every live entry into a fresh bucket array of @p new_capacity.
  ///
  /// Entries are spliced: list nodes are relinked, not copied. The
  /// WeakIntrusivePtr objects therefore keep their addresses, and the vertices'
  /// back-pointers stay valid. Expired entries are left behind and destroyed
  /// together with the old bucket array.
  void Rehash(int new_capacity) {
    int new_size = 0;
    Table new_table(new_capacity);
    for (Bucket& chain : table_) {
      for (auto it_prev = chain.before_begin(), it_cur = chain.begin(),
                it_end = chain.end();
           it_cur != it_end;) {
        if (it_cur->expired()) {
          it_prev = it_cur;
          ++it_cur;
          continue;
        }
        ++new_size;
        T* vertex = it_cur->get();
        int bucket_number =
            Hash(vertex->index(), get_high_id(*vertex), get_low_id(*vertex)) %
            new_capacity;
        Bucket& new_chain = new_table[bucket_number];
        // Moves the single node in (it_prev, it_cur + 1). Afterwards it_prev
        // is still the predecessor of the advanced it_cur.
        new_chain.splice_after(new_chain.before_begin(), chain, it_prev,
                               ++it_cur);
      }
    }
    table_.swap(new_table);
    size_ = new_size;
    capacity_ = new_capacity;
  }

  /// @returns The combined hash of the vertex key triplet.
  static std::size_t Hash(int index, int high_id, int low_id) {
    std::size_t seed = 0;
    boost::hash_combine(seed, index);
    boost::hash_combine(seed, high_id);
    boost::hash_combine(seed, low_id);
    return seed;
  }

  /// Computes the next bucket count for growing the table.
  ///
  /// Smaller tables grow faster: one extra doubling per decade below
  /// kMaxScaleCapacity. At or above it, the table just doubles. The result is
  /// clamped to the int range instead of overflowing.
  static int GetNextCapacity(int prev_capacity) {
    constexpr int kMaxScaleCapacity = 100000000;
    int scale_power = 1;  // The default power after the max scale capacity.
    if (prev_capacity < kMaxScaleCapacity) {
      // The integer division and truncation are intentional (coarse decades).
      scale_power +=
          static_cast<int>(std::log10(kMaxScaleCapacity / prev_capacity));
    }
    // Use 64 bits: doubling beyond ~1.07e9 would overflow a signed int (UB).
    long long new_capacity = static_cast<long long>(prev_capacity)
                             << scale_power;
    new_capacity =
        std::min<long long>(new_capacity, std::numeric_limits<int>::max());
    return core::GetPrimeNumber(static_cast<int>(new_capacity));
  }

  int capacity_;  ///< The number of buckets.
  int size_;  ///< The number of entries (possibly including expired ones).
  double max_load_factor_;  ///< The size/capacity ratio that triggers growth.

  Table table_;  ///< The bucket array (empty after Release()).
};

/// A hash table without collision resolution.
///
/// Each key hashes to exactly one slot. Inserting into an occupied slot
/// overwrites (purges) the previous entry, which is acceptable for a
/// computation cache.
///
/// @tparam V  The cached value type. It must be contextually convertible to
///            bool (false meaning "empty slot") and provide reset() and swap().
template <class V>
class CacheTable {
 public:
  using key_type = std::pair<int, int>;
  using mapped_type = V;
  using value_type = std::pair<key_type, mapped_type>;
  using container_type = std::vector<value_type>;
  using iterator = typename container_type::iterator;

  /// @param init_capacity  Requested initial number of slots
  ///                       (rounded through GetPrimeNumber).
  explicit CacheTable(int init_capacity = 1000)
      : size_(0),
        max_load_factor_(0.75),
        table_(core::GetPrimeNumber(init_capacity)) {}

  /// @returns The number of occupied slots.
  int size() const { return size_; }

  /// Empties every slot but keeps the storage.
  void clear() {
    for (value_type& entry : table_) {
      if (entry.second)
        entry.second.reset();
    }
    size_ = 0;
  }

  /// Prepares the table for @p n entries.
  ///
  /// reserve(0) on an empty table releases all storage. The table stays
  /// usable afterwards: find() reports a miss and emplace() re-grows it.
  void reserve(int n) {
    if (size_ == 0 && n == 0) {
      table_ = decltype(table_)();
      return;
    }
    if (n <= size_)
      return;
    Rehash(core::GetPrimeNumber(static_cast<int>(n / max_load_factor_ + 1)));
  }

  /// @returns An iterator to the entry with @p key, or end() on a miss.
  iterator find(const key_type& key) {
    if (table_.empty())  // Storage was released; avoid hashing modulo zero.
      return table_.end();
    int index = boost::hash_value(key) % table_.size();
    value_type& entry = table_[index];
    if (!entry.second || entry.first != key)
      return table_.end();
    return table_.begin() + index;
  }

  /// @returns The past-the-end iterator used to signal a miss.
  iterator end() { return table_.end(); }

  /// Stores @p value under @p key, overwriting whatever occupies the slot.
  ///
  /// @param key  The cache key.
  /// @param value  A non-empty value (asserted).
  void emplace(const key_type& key, const mapped_type& value) {
    assert(value && "Empty computation results!");

    if (size_ >= (max_load_factor_ * table_.size())) {
      // Guard against requesting a prime for capacity 0 after reserve(0).
      int grown = std::max(static_cast<int>(table_.size()) * 2, 1);
      Rehash(core::GetPrimeNumber(grown));
    }

    int index = boost::hash_value(key) % table_.size();
    value_type& entry = table_[index];
    if (!entry.second)
      ++size_;
    entry.first = key;  // Key equality is unlikely for the use case.
    entry.second = value;  // Might be purging another value.
  }

 private:
  /// Moves the occupied entries into a new slot array of @p new_capacity.
  ///
  /// Entries that collide in the new array overwrite each other; only the
  /// last one survives.
  void Rehash(int new_capacity) {
    int new_size = 0;
    std::vector<value_type> new_table(new_capacity);
    for (value_type& entry : table_) {
      if (!entry.second)
        continue;
      int new_index = boost::hash_value(entry.first) % new_table.size();
      value_type& new_entry = new_table[new_index];
      new_entry.first = entry.first;
      if (!new_entry.second)
        ++new_size;
      new_entry.second.swap(entry.second);
    }
    size_ = new_size;
    table_.swap(new_table);
  }

  int size_;  ///< The number of occupied slots.
  double max_load_factor_;  ///< The size/capacity ratio that triggers growth.
  std::vector<value_type> table_;  ///< The slots (empty after reserve(0)).
};

class Zbdd;  // For analysis purposes.

/// Analysis of PDAGs with Binary Decision Diagrams.
///
/// Non-terminal vertices are if-then-else (Ite) vertices. Results are held as
/// Function values (a vertex plus a complement flag). Reduction is maintained
/// through a UniqueTable, and AND/OR computations are memoized in CacheTables.
/// Member function bodies are defined elsewhere.
class Bdd : private boost::noncopyable {
 public:
  using VertexPtr = IntrusivePtr<Vertex<Ite>>;  ///< Generic vertex pointer.
  using TerminalPtr = IntrusivePtr<Terminal<Ite>>;  ///< Terminal pointer.

  /// Holder of computation resultant functions and gate representations.
  ///
  /// Still an aggregate. The default member initializer makes a
  /// default-initialized Function deterministic (no complement) instead of
  /// leaving the flag indeterminate.
  struct Function {
    bool complement = false;  ///< Complement flag applied to the vertex.
    VertexPtr vertex;  ///< The root vertex; empty means "no function".

    /// @returns true if the function holds a vertex.
    explicit operator bool() const { return vertex != nullptr; }

    /// Drops the vertex, making the function empty.
    void reset() { vertex = nullptr; }

    /// Swaps both members with @p other.
    void swap(Function& other) noexcept {
      std::swap(complement, other.complement);
      vertex.swap(other.vertex);
    }
  };

  /// Provides access to consensus calculation private facilities.
  ///
  /// Only Zbdd (a friend) can invoke it; it forwards to
  /// Bdd::CalculateConsensus.
  class Consensus {
    friend class Zbdd;  // Access for calculation of prime implicants.

    Function operator()(Bdd* bdd, const ItePtr& ite, bool complement) noexcept {
      return bdd->CalculateConsensus(ite, complement);
    }
  };

  /// @param graph  The PDAG to analyze (presumably converted into the BDD here).
  /// @param settings  The analysis settings (copied into kSettings_).
  Bdd(const Pdag* graph, const Settings& settings);

  ~Bdd() noexcept;

  /// @returns The root function of the BDD.
  const Function& root() const { return root_; }

  /// @returns The module functions keyed by an integer (presumably gate index).
  const std::unordered_map<int, Function>& modules() const { return modules_; }

  /// @returns The mapping from variable indices to their orders.
  const std::unordered_map<int, int>& index_to_order() const {
    return index_to_order_;
  }

  /// @returns The coherence flag of the BDD.
  bool coherent() const { return coherent_; }

  /// Sets the traversal marks reachable from the root to @p mark.
  void ClearMarks(bool mark) { ClearMarks(root_.vertex, mark); }

  /// Runs the analysis; the result is available through products().
  void Analyze(const Pdag* graph = nullptr) noexcept;

  /// @returns The ZBDD of products (asserts that Analyze() has produced it).
  const Zbdd& products() const {
    assert(zbdd_ && "Analysis is not done.");
    return *zbdd_;
  }

 private:
  using IteWeakPtr = WeakIntrusivePtr<Ite>;  ///< Unique-table entry type.
  using ComputeTable = CacheTable<Function>;  ///< Memoization table type.

  ItePtr FindOrAddVertex(int index, const VertexPtr& high, const VertexPtr& low,
                         bool complement_edge, int order) noexcept;

  ItePtr FindOrAddVertex(const ItePtr& ite, const VertexPtr& high,
                         const VertexPtr& low, bool complement_edge) noexcept;

  ItePtr FindOrAddVertex(const Gate& gate, const VertexPtr& high,
                         const VertexPtr& low, bool complement_edge) noexcept;

  Function ConvertGraph(
      const Gate& gate,
      std::unordered_map<int, std::pair<Function, int>>* gates) noexcept;

  std::pair<int, int> GetMinMaxId(const VertexPtr& arg_one,
                                  const VertexPtr& arg_two, bool complement_one,
                                  bool complement_two) noexcept;

  template <Connective Type>
  Function Apply(const VertexPtr& arg_one, const VertexPtr& arg_two,
                 bool complement_one, bool complement_two) noexcept;

  template <Connective Type>
  Function Apply(ItePtr ite_one, ItePtr ite_two, bool complement_one,
                 bool complement_two) noexcept;

  Function Apply(Connective type, const VertexPtr& arg_one,
                 const VertexPtr& arg_two, bool complement_one,
                 bool complement_two) noexcept;

  Function CalculateConsensus(const ItePtr& ite, bool complement) noexcept;

  int CountIteNodes(const VertexPtr& vertex) noexcept;

  void ClearMarks(const VertexPtr& vertex, bool mark) noexcept;

  void TestStructure(const VertexPtr& vertex) noexcept;

  /// Empties the AND/OR computation caches (their storage is kept).
  void ClearTables() noexcept {
    and_table_.clear();
    or_table_.clear();
  }

  /// Releases the memory of the unique table and of both computation caches.
  void Freeze() noexcept {
    unique_table_.Release();
    ClearTables();
    and_table_.reserve(0);
    or_table_.reserve(0);
  }

  const Settings kSettings_;  ///< Analysis settings.
  Function root_;  ///< The root function.
  bool coherent_;  ///< Coherence flag (set by the constructor, not visible here).

  UniqueTable<Ite> unique_table_;  ///< Keeps the BDD reduced.

  ComputeTable and_table_;  ///< Memoized AND results.
  ComputeTable or_table_;  ///< Memoized OR results.

  std::unordered_map<int, Function> modules_;  ///< Module functions.
  std::unordered_map<int, int> index_to_order_;  ///< Variable index -> order.
  const TerminalPtr kOne_;  ///< The terminal vertex (presumably "true").
  int function_id_;  ///< Id counter for new vertices (presumably).
  std::unique_ptr<Zbdd> zbdd_;  ///< The products computed by Analyze().
};

}  // namespace scram::core

Updated on 2025-11-11 at 16:51:08 +0000