Skip to content

packages/engine/scram-node/src/zbdd.h

Zero-Suppressed Binary Decision Diagram facilities.

Namespaces

Name
scram
scram::core
scram::core::zbdd

Classes

Name
class scram::core::SetNode — Representation of non-terminal nodes in ZBDD.
struct scram::core::PairHash — Functor for hashing pairs of ordered numbers.
struct scram::core::TripletHash — Functor for hashing triplets of ordered numbers.
class scram::core::Zbdd — Zero-Suppressed Binary Decision Diagrams for set manipulations.
class scram::core::Zbdd::const_iterator — Iterator over products in a ZBDD container.
class scram::core::zbdd::CutSetContainer — ZBDD container for building and manipulating cut sets.

Source code

cpp
/*
 * Copyright (C) 2014-2018 Olzhas Rakhimov
 * Copyright (C) 2023 OpenPRA ORG Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


#pragma once

#include <cstdint>

#include <array>
#include <map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/functional/hash.hpp>
#include <boost/iterator/iterator_facade.hpp>
#include <boost/noncopyable.hpp>

#include "bdd.h"
#include "pdag.h"

namespace scram::core {

/// Non-terminal ZBDD vertex with extra per-node bookkeeping.
///
/// On top of everything inherited from NonTerminal,
/// a node stores a minimality flag, a maximum set order, and a 64-bit count.
/// All three start at false / 0 and are only changed through the setters.
class SetNode : public NonTerminal<SetNode> {
 public:
  using NonTerminal::NonTerminal;  // Reuse every base-class constructor.

  /// @returns The minimality flag of this node.
  bool minimal() const {
    return minimal_;
  }

  /// @returns The recorded maximum set order.
  int max_set_order() const {
    return max_set_order_;
  }

  /// @returns The recorded count
  ///          (presumably the number of products under the node; TODO confirm).
  std::int64_t count() const {
    return count_;
  }

  /// Sets the minimality flag.
  void minimal(bool flag) {
    minimal_ = flag;
  }

  /// Records the maximum set order.
  void max_set_order(int order) {
    max_set_order_ = order;
  }

  /// Records the count.
  void count(std::int64_t number) {
    count_ = number;
  }

 private:
  bool minimal_{false};  ///< Minimality flag.
  int max_set_order_{0};  ///< Maximum set order.
  std::int64_t count_{0};  ///< Count value.
};

/// Intrusive pointer to a SetNode.
using SetNodePtr = IntrusivePtr<SetNode>;

struct PairHash {
  std::size_t operator()(const std::pair<int, int>& p) const noexcept {
    return boost::hash_value(p);
  }
};

template <typename Value>
using PairTable = std::unordered_map<std::pair<int, int>, Value, PairHash>;

/// Fixed-size key made of three ordered ints.
using Triplet = std::array<int, 3>;

/// Hash functor for Triplet keys.
struct TripletHash {
  /// @param triplet  The key to hash; element order matters.
  /// @returns Boost's range hash over the three elements, first to last.
  std::size_t operator()(const Triplet& triplet) const noexcept {
    const auto first = triplet.cbegin();
    const auto last = triplet.cend();
    return boost::hash_range(first, last);
  }
};

/// Hash table keyed by Triplet.
template <typename Value>
using TripletTable =
    std::unordered_map<Triplet, Value, TripletHash>;

class Zbdd : private boost::noncopyable {
 public:
  using VertexPtr = IntrusivePtr<Vertex<SetNode>>;  
  using TerminalPtr = IntrusivePtr<Terminal<SetNode>>;  

  class const_iterator
      : public boost::iterator_facade<const_iterator, const std::vector<int>,
                                      boost::forward_traversal_tag> {
    friend class boost::iterator_core_access;

    class module_iterator {
     public:
      module_iterator(const SetNode* node, const Zbdd& zbdd, const_iterator* it,
                      bool sentinel = false)
          : sentinel_(sentinel),
            start_pos_(it->product_.size()),
            end_pos_(start_pos_),
            it_(*it),
            node_(node),
            zbdd_(zbdd) {
        if (!sentinel_) {
          sentinel_ = !GenerateProduct(zbdd_.root());
          end_pos_ = it_.product_.size();
        }
      }

      module_iterator(module_iterator&&) noexcept = default;

      explicit operator bool() const { return !sentinel_; }

      void operator++() {
        if (sentinel_)
          return;
        assert(end_pos_ >= start_pos_ && "Corrupted sentinel.");
        while (start_pos_ != it_.product_.size()) {
          if (!module_stack_.empty() &&
              it_.product_.size() == module_stack_.back().end_pos_) {
            const SetNode* node = module_stack_.back().node_;
            for (++module_stack_.back(); module_stack_.back();
                 ++module_stack_.back()) {
              if (GenerateProduct(node->high()))
                goto outer_break;
            }
            module_stack_.pop_back();
            if (GenerateProduct(node->low()))
              break;

          } else if (GenerateProduct(Pop()->low())) {
            break;
          }
        }
      outer_break:
        end_pos_ = it_.product_.size();
        sentinel_ = start_pos_ == end_pos_;
      }

     private:
      bool GenerateProduct(const VertexPtr& vertex) noexcept {
        if (vertex->terminal())
          return Terminal<SetNode>::Ref(vertex).value();
        if (it_.product_.size() >= it_.zbdd_.settings().limit_order())
          return false;
        const SetNode& node = SetNode::Ref(vertex);
        if (node.module()) {
          module_stack_.emplace_back(
              &node, *zbdd_.modules_.find(node.index())->second, &it_);
          for (; module_stack_.back(); ++module_stack_.back()) {
            if (GenerateProduct(node.high()))
              return true;
          }
          assert(it_.product_.size() == module_stack_.back().start_pos_);
          module_stack_.pop_back();
          return GenerateProduct(node.low());

        } else {
          Push(&node);
          return GenerateProduct(node.high()) || GenerateProduct(Pop()->low());
        }
      }

      const SetNode* Pop() noexcept {
        assert(start_pos_ < it_.product_.size() && "Access beyond the range!");
        const SetNode* leaf = it_.node_stack_.back();
        it_.node_stack_.pop_back();
        it_.product_.pop_back();
        return leaf;
      }

      void Push(const SetNode* set_node) noexcept {
        it_.node_stack_.push_back(set_node);
        it_.product_.push_back(set_node->index());
      }

      bool sentinel_;  
      const int start_pos_;  
      int end_pos_;  
      const_iterator& it_;  
      const SetNode* node_;  
      const Zbdd& zbdd_;  
      std::vector<module_iterator> module_stack_;
    };

   public:
    explicit const_iterator(const Zbdd& zbdd, bool sentinel = false)
        : sentinel_(sentinel), zbdd_(zbdd), it_(nullptr, zbdd, this, sentinel) {
      sentinel_ = !it_;
    }

    const_iterator(const const_iterator& other) noexcept
        : sentinel_(other.sentinel_),
          zbdd_(other.zbdd_),
          it_(nullptr, zbdd_, this, sentinel_) {
      assert(*this == other && "Copy ctor is only for begin/end iterators.");
    }

   private:
    void increment() {
      assert(!sentinel_ && "Incrementing an end iterator.");
      ++it_;
      sentinel_ = !it_;
    }
    bool equal(const const_iterator& other) const {
      assert(!(sentinel_ && !product_.empty()) && "Uncleared products.");
      return sentinel_ == other.sentinel_ && &zbdd_ == &other.zbdd_ &&
             product_ == other.product_;
    }
    const std::vector<int>& dereference() const {
      assert(!sentinel_ && "Dereferencing end iterator.");
      return product_;
    }

    bool sentinel_;  
    const Zbdd& zbdd_;  
    std::vector<int> product_;  
    std::vector<const SetNode*> node_stack_;  
    module_iterator it_;  
  };

  Zbdd(Bdd* bdd, const Settings& settings) noexcept;

  Zbdd(const Pdag* graph, const Settings& settings) noexcept;

  virtual ~Zbdd() noexcept = default;

  void Analyze(const Pdag* graph = nullptr) noexcept;

  const Zbdd& products() const { return *this; }

  auto begin() const { return const_iterator(*this); }
  auto end() const { return const_iterator(*this, /*sentinel=*/true); }

  std::size_t size() const { return std::distance(begin(), end()); }

  bool empty() const { return begin() == end(); }

  bool base() const { return root_ == kBase_; }

 protected:
  explicit Zbdd(const Settings& settings, bool coherent = false,
                int module_index = 0) noexcept;

  const VertexPtr& root() const { return root_; }

  void root(const VertexPtr& vertex) { root_ = vertex; }

  const Settings& settings() const { return kSettings_; }

  const std::map<int, std::unique_ptr<Zbdd>>& modules() const {
    return modules_;
  }

  void Log() noexcept;

  SetNodePtr FindOrAddVertex(int index, const VertexPtr& high,
                             const VertexPtr& low, int order,
                             bool module = false,
                             bool coherent = false) noexcept;

  SetNodePtr FindOrAddVertex(const Gate& gate, const VertexPtr& high,
                             const VertexPtr& low) noexcept;

  template <Connective Type>
  VertexPtr Apply(const VertexPtr& arg_one, const VertexPtr& arg_two,
                  int limit_order) noexcept;

  VertexPtr Apply(Connective type, const VertexPtr& arg_one,
                  const VertexPtr& arg_two, int limit_order) noexcept;

  template <Connective Type>
  VertexPtr Apply(const SetNodePtr& arg_one, const SetNodePtr& arg_two,
                  int limit_order) noexcept;

  VertexPtr EliminateComplements(
      const VertexPtr& vertex,
      std::unordered_map<int, VertexPtr>* wide_results) noexcept;

  void EliminateConstantModules() noexcept;

  VertexPtr Minimize(const VertexPtr& vertex) noexcept;

  int GatherModules(const VertexPtr& vertex, int current_order,
                    std::map<int, std::pair<bool, int>>* modules) noexcept;

  void ApplySubstitutions(
      const std::vector<Pdag::Substitution>& substitutions) noexcept;

  void ClearTables() noexcept {
    and_table_.clear();
    or_table_.clear();
    minimal_results_.clear();
    subsume_table_.clear();
    prune_results_.clear();
  }

  void Freeze() noexcept {
    unique_table_.Release();
    Zbdd::ClearTables();
    and_table_.reserve(0);
    or_table_.reserve(0);
    minimal_results_.reserve(0);
    subsume_table_.reserve(0);
  }

  void JoinModule(int index, std::unique_ptr<Zbdd> container) noexcept {
    assert(!modules_.count(index));
    assert(container->root()->terminal() ||
           SetNode::Ref(container->root()).minimal());
    modules_.emplace(index, std::move(container));
  }

  const TerminalPtr kBase_;  
  const TerminalPtr kEmpty_;  

 private:
  using SetNodeWeakPtr = WeakIntrusivePtr<SetNode>;  
  using ComputeTable = TripletTable<VertexPtr>;  
  using ModuleEntry = std::pair<const int, std::unique_ptr<Zbdd>>;

  Zbdd(const Bdd::Function& module, bool coherent, Bdd* bdd,
       const Settings& settings, int module_index = 0) noexcept;

  Zbdd(const Gate& gate, const Settings& settings) noexcept;

  SetNodePtr FindOrAddVertex(const SetNodePtr& node, const VertexPtr& high,
                             const VertexPtr& low) noexcept;

  VertexPtr GetReducedVertex(const ItePtr& ite, bool complement,
                             const VertexPtr& high,
                             const VertexPtr& low) noexcept;

  VertexPtr GetReducedVertex(const SetNodePtr& node, const VertexPtr& high,
                             const VertexPtr& low) noexcept;

  Triplet GetResultKey(const VertexPtr& arg_one, const VertexPtr& arg_two,
                       int limit_order) noexcept;

  VertexPtr ConvertBdd(const Bdd::VertexPtr& vertex, bool complement,
                       Bdd* bdd_graph, int limit_order,
                       PairTable<VertexPtr>* ites) noexcept;

  VertexPtr ConvertBdd(const ItePtr& ite, bool complement, Bdd* bdd_graph,
                       int limit_order, PairTable<VertexPtr>* ites) noexcept;

  VertexPtr ConvertBddPrimeImplicants(const ItePtr& ite, bool complement,
                                      Bdd* bdd_graph, int limit_order,
                                      PairTable<VertexPtr>* ites) noexcept;

  VertexPtr
  ConvertGraph(const Gate& gate,
               std::unordered_map<int, std::pair<VertexPtr, int>>* gates,
               std::unordered_map<int, const Gate*>* module_gates) noexcept;

  VertexPtr EliminateComplement(const SetNodePtr& node, const VertexPtr& high,
                                const VertexPtr& low) noexcept;

  VertexPtr EliminateConstantModules(
      const VertexPtr& vertex,
      std::unordered_map<int, VertexPtr>* results) noexcept;

  VertexPtr EliminateConstantModule(const SetNodePtr& node,
                                    const VertexPtr& high,
                                    const VertexPtr& low) noexcept;

  VertexPtr Subsume(const VertexPtr& high, const VertexPtr& low) noexcept;

  VertexPtr Prune(const VertexPtr& vertex, int limit_order) noexcept;

  virtual bool IsGate(const SetNode& node) noexcept { return node.module(); }

  bool MayBeUnity(const SetNode& node) noexcept;

  int CountSetNodes(const VertexPtr& vertex) noexcept;

  std::int64_t CountProducts(const VertexPtr& vertex, bool modules) noexcept;

  void ClearMarks(const VertexPtr& vertex, bool modules) noexcept;

  void ClearCounts(const VertexPtr& vertex, bool modules) noexcept;

  void TestStructure(const VertexPtr& vertex, bool modules) noexcept;

  const Settings kSettings_;  
  VertexPtr root_;  
  bool coherent_;  
  int module_index_;  

  UniqueTable<SetNode> unique_table_;

  ComputeTable and_table_;
  ComputeTable or_table_;

  std::unordered_map<int, VertexPtr> minimal_results_;
  PairTable<VertexPtr> subsume_table_;
  PairTable<VertexPtr> prune_results_;

  std::map<int, std::unique_ptr<Zbdd>> modules_;  
  int set_id_;  
};

namespace zbdd {

class CutSetContainer : public Zbdd {
 public:
  CutSetContainer(const Settings& settings, int module_index,
                  int gate_index_bound) noexcept;

  VertexPtr ConvertGate(const Gate& gate) noexcept;

  int GetNextGate() noexcept {
    if (Zbdd::root()->terminal())
      return 0;
    SetNode& node = SetNode::Ref(Zbdd::root());
    return CutSetContainer::IsGate(node) && !node.module() ? node.index() : 0;
  }

  VertexPtr ExtractIntermediateCutSets(int index) noexcept;

  VertexPtr ExpandGate(const VertexPtr& gate_zbdd,
                       const VertexPtr& cut_sets) noexcept;

  void Merge(const VertexPtr& vertex) noexcept;

  void EliminateComplements() noexcept {
    std::unordered_map<int, VertexPtr> wide_results;
    Zbdd::root(Zbdd::EliminateComplements(Zbdd::root(), &wide_results));
  }

  void EliminateConstantModules() noexcept { Zbdd::EliminateConstantModules(); }

  void Minimize() noexcept { Zbdd::root(Zbdd::Minimize(Zbdd::root())); }

  std::map<int, std::pair<bool, int>> GatherModules() noexcept {
    assert(Zbdd::modules().empty() && "Unexpected call with defined modules?!");
    std::map<int, std::pair<bool, int>> modules;
    Zbdd::GatherModules(Zbdd::root(), 0, &modules);
    return modules;
  }

  using Zbdd::JoinModule;  
  using Zbdd::Log;  

 private:
  bool IsGate(const SetNode& node) noexcept override {
    return node.index() > gate_index_bound_;
  }

  int gate_index_bound_;  
};

}  // namespace zbdd

}  // namespace scram::core

Updated on 2025-11-11 at 16:51:09 +0000