ソフトウェア科学特論: ScalaとSATソルバー

Table of Contents

1 概要

本稿では,Scalaで記述したSATソルバーについて説明する. SATおよびSATソルバーについては 命題論理とSAT を参照のこと.

1.1 注意

本Webページ(およびPDF)の作成には Emacs org-mode を用いており, 数式等の表示は MathJax を用いています. IEでは正しく表示されないことがあるため, Firefox, Safari等のWebブラウザでJavaScriptを有効にしてお使いください. また org-info.js を利用しており, 「m」キーをタイプするとinfoモードでの表示になります. 利用できるショートカットは「?」で表示されます.

1.2 参考文献

2 DPLL

命題論理とSAT でDPLLアルゴリズムを実装した 以下のプログラムを示した (prop-sat11.scala).

  1:  package object sat {
  2:    type Literal = Int
  3:  }
  4:  
  5:  package sat {
  6:    case class Assignment(amap: Map[Int,Boolean] = Map.empty) {
  7:      def isDefinedAt(lit: Literal) =
  8:        amap.isDefinedAt(math.abs(lit))
  9:      def apply(lit: Literal) = amap.get(math.abs(lit)) match {
 10:        case None => None
 11:        case Some(v) => Some(lit < 0 ^ v)
 12:      }
 13:      def + (litValue: (Literal,Boolean)) = {
 14:        val (lit, value) = litValue
 15:        if (lit < 0)
 16:          Assignment(amap + (-lit -> ! value))
 17:        else
 18:          Assignment(amap + (lit -> value))
 19:      }
 20:    }
 21:  
 22:    case class Clause(literals: Set[Literal]) {
 23:      def value(assignment: Assignment): Option[Boolean] =
 24:        literals.map(assignment(_)).reduceLeft((x,y) => (x,y) match {
 25:          case (Some(false), Some(false)) => Some(false)
 26:          case (Some(true), _) | (_, Some(true)) => Some(true)
 27:          case _ => None
 28:        })
 29:      def apply(assignment: Assignment) =
 30:        literals.filter(lit => ! assignment.isDefinedAt(lit))
 31:      override def toString =
 32:        literals.mkString("{", ", ", "}")
 33:    }
 34:  
 35:    class DPLL(clauses: Set[Clause]) {
 36:      var debug = false
 37:      def unitPropagation(assignment: Assignment): Assignment = {
 38:        var a: Assignment = assignment
 39:        var change = true
 40:        while (change) {
 41:          change = false
 42:          for (clause <- clauses) {
 43:            if (clause.value(a) != Some(true)) {
 44:              val c = clause(a)
 45:              if (c.size == 1) {
 46:                val lit = c.head
 47:                if (debug) println("Propagate : " + lit + " -> " + true)
 48:                a = a + (lit -> true)
 49:                change = true
 50:              }
 51:            }
 52:          }
 53:        }
 54:        a
 55:      }
 56:      def select(assignment: Assignment): Literal = {
 57:        val clause = clauses.filter(_.value(assignment) != Some(true)).head
 58:        clause(assignment).head
 59:      }
 60:      def decide(assignment: Assignment,
 61:                 lit: Literal, value: Boolean): Option[Assignment] = {
 62:        if (debug) println("Decide : " + lit + " -> " + value)
 63:        solve(assignment + (lit -> value))
 64:      }
 65:      def solve(assignment: Assignment): Option[Assignment] = {
 66:        val a = unitPropagation(assignment)
 67:        if (clauses.exists(_.value(a) == Some(false))) {
 68:          if (debug) println("Backtrack")
 69:          None
 70:        } else if (clauses.forall(_.value(a) == Some(true))) {
 71:          Some(a)
 72:        } else {
 73:          val p = select(a)
 74:          decide(a, p, false) orElse decide(a, p, true)
 75:        }
 76:      }
 77:      def solve: Option[Assignment] =
 78:        solve(Assignment(Map.empty))
 79:    }
 80:  
 81:    object DPLL {
 82:      def parse(file: String): (Int, Set[Clause]) = {
 83:        val clauses = scala.collection.mutable.Set[Clause]()
 84:        var numOfVariables = 0
 85:        for (line <- scala.io.Source.fromFile(file).getLines()) {
 86:          if (line.startsWith("p ")) {
 87:            numOfVariables = line.split("\\s+")(2).toInt
 88:          } else if (line.startsWith("c ") || line.matches("\\s*")) {
 89:          } else if (line.matches("(-?\\d+\\s+)*0")) {
 90:            var lits = Set[Literal]()
 91:            for (lit <- line.split("\\s+").map(_.toInt).dropRight(1))
 92:              lits += lit
 93:            clauses += Clause(lits)
 94:          } else {
 95:            println("CNF file format error in " + file + ": " + line)
 96:          }
 97:        }
 98:        (numOfVariables, clauses.toSet)
 99:      }
100:      def main(args: Array[String]) {
101:        val (numOfVariables, clauses) = parse(args(0))
102:        val solver = new DPLL(clauses)
103:        if (args.size >= 2 && args(1) == "-d") {
104:          println(clauses)
105:          solver.debug = true
106:        }
107:        solver.solve match {
108:          case None =>
109:            println("UNSAT")
110:          case Some(assignment) => {
111:            println("SAT")
112:            for (i <- 1 to numOfVariables) {
113:              assignment(i) match {
114:                case None | Some(true) => print(i)
115:                case Some(false) => print(-i)
116:              }
117:              print(" ")
118:            }
119:            println("0")
120:          }
121:        }
122:      }
123:    }
124:  }

2.1 実行例

SAT型制約ソルバーSugar を用いて, 数独の問題をSAT問題に符号化したものを実行サンプルとして利用する (sudoku-15.cnf)

元の数独の問題は Nikoli: 数独のおためし問題 の「おためし問題20」で, 制約モデル化は 数独パズルをSugar制約ソルバーで解く の方法を用い, Sugar によりSAT問題に符号化した.

実行例は以下の通りであり, Intel Core i5-2540M vPro 2.60GHz x 2 のマシンで100秒近いCPU時間となった.

$ scalac prop-sat11.scala
$ time scala sat.DPLL sudoku-20.cnf
SAT
-1 -2 -3 4 5 6 7 8 -9 ... 0

real    1m33.450s
user    1m36.038s
sys     0m0.220s

一方 MiniSat 2.2 のCPU時間はほぼ0秒である.

2.2 実行時間の計測

そこで Java のプロファイラー機能を利用し, 上記プログラムのどこが遅いのかを計測してみる.

$ JAVA_OPTS="-Xrunhprof:cpu=samples" scala sat.DPLL sudoku-20.cnf

測定結果として java.hprof.txt という名前のファイルが作成される. このファイルの最後のほうに, 以下のようにJavaのメソッド毎の実行時間が記載されている.

CPU SAMPLES BEGIN (total = 9510) Sun Jul  3 21:42:45 2011
rank   self  accum   count trace method
   1 22.37% 22.37%    2127 300226 scala.collection.immutable.HashMap$HashTrieMap.get0
   2 13.42% 35.78%    1276 300221 scala.collection.immutable.HashMap$HashMap1.get0
   3  6.42% 42.21%     611 300211 scala.collection.mutable.AddingBuilder.$plus$eq
   4  5.55% 47.76%     528 300233 scala.collection.immutable.HashMap$HashMap1.get0
   5  4.97% 52.73%     473 300193 scala.collection.immutable.ListSet$Node.elem
   6  4.05% 56.78%     385 300247 scala.collection.TraversableLike$class.map
   7  3.91% 60.69%     372 300261 scala.collection.immutable.HashMap$HashTrieMap.get0
   8  3.35% 64.05%     319 300256 scala.collection.TraversableLike$class.map
   9  3.21% 67.26%     305 300227 scala.collection.TraversableOnce$$anonfun$reduceLeft$1.apply
  10  2.72% 69.98%     259 300266 scala.collection.immutable.ListSet$$anon$1.next
   .....
CPU SAMPLES END

これを見ると Map 関係のメソッドで多くの時間を消費していることがわかる.

2.3 真理値割当を配列に変更

上のプログラムでは真理値割当を Map で表現しており, これが多くの時間を消費していた.

そこで,命題変数を正の整数で表現し, 真理値割当を命題変数を添字とした配列 (Array)で表現することにする. また命題変数が正の整数なので,リテラルを整数で表現する(負リテラルが負数).

 1:  package object sat {
 2:    type Literal = Int
 3:  }
 4:  
 5:  package sat {
 6:    class Assignment extends (Literal => Option[Boolean]) {
 7:      def this(numOfVariables: Int) = {
 8:        this()
 9:        assignment = new Array(numOfVariables + 1)
10:        for (i <- 0 to numOfVariables)
11:          assignment(i) = None
12:      }
13:      var assignment: Array[Option[Boolean]] = null
14:      def size = assignment.size - 1
15:      def isDefinedAt(lit: Literal) =
16:        assignment(math.abs(lit)).isDefined
17:      def apply(lit: Literal) = 
18:        if (lit < 0) assignment(-lit).map(! _) else assignment(lit)
19:      def update(lit: Literal, value: Boolean) =
20:        if (lit < 0)
21:          assignment(-lit) = Some(! value)
22:        else
23:          assignment(lit) = Some(value)
24:      def clear(lit: Literal) =
25:        assignment(math.abs(lit)) = None
26:      override def toString = {
27:        val a = for {
28:          i <- 1 to size if assignment(i).isDefined
29:        } yield (i, assignment(i).get)
30:        a.toMap.mkString("{", ", ", "}")
31:      }
32:    }
33:  }

同様に,節も整数であるリテラルの集合として定義しなおす (package sat中に記述する).

 1:  case class Clause(literals: Set[Literal]) {
 2:    def value(assignment: Assignment): Option[Boolean] =
 3:      literals.map(assignment(_)).reduceLeft((x,y) => (x,y) match {
 4:        case (Some(false), Some(false)) => Some(false)
 5:        case (Some(true), _) | (_, Some(true)) => Some(true)
 6:        case _ => None
 7:      })
 8:    def apply(assignment: Assignment) =
 9:      literals.filter(lit => ! assignment.isDefinedAt(lit))
10:    override def toString =
11:      literals.mkString("{", ", ", "}")
12:  }

前のプログラムでは, 真理値割当を immutable な Map で表現しており, バックトラックの処理が簡単だったが, 今回は mutable なデータ構造である Array のため, バックトラック時に Array 中の値を元に戻す必要がある.

そこで trail スタックと呼ばれる ArrayStack を用意し, どの命題変数に対する真理値割当だったかを記録する. バックトラック時にはそれらの命題変数に対する真理値割当を 削除する(Noneを代入する).

修正したプログラムは以下のようになる (prop-sat13.scala).

  1:  import scala.collection.mutable.ArrayStack
  2:  
  3:  package object sat {
  4:    type Literal = Int
  5:  }
  6:  
  7:  package sat {
  8:    class Assignment extends (Literal => Option[Boolean]) {
  9:      def this(numOfVariables: Int) = {
 10:        this()
 11:        assignment = new Array(numOfVariables + 1)
 12:        for (i <- 0 to numOfVariables)
 13:          assignment(i) = None
 14:      }
 15:      var assignment: Array[Option[Boolean]] = null
 16:      def size = assignment.size - 1
 17:      def isDefinedAt(lit: Literal) =
 18:        assignment(math.abs(lit)).isDefined
 19:      def apply(lit: Literal) = 
 20:        if (lit < 0) assignment(-lit).map(! _) else assignment(lit)
 21:      def update(lit: Literal, value: Boolean) =
 22:        if (lit < 0)
 23:          assignment(-lit) = Some(! value)
 24:        else
 25:          assignment(lit) = Some(value)
 26:      def clear(lit: Literal) =
 27:        assignment(math.abs(lit)) = None
 28:      override def toString = {
 29:        val a = for {
 30:          i <- 1 to size if assignment(i).isDefined
 31:        } yield (i, assignment(i).get)
 32:        a.toMap.mkString("{", ", ", "}")
 33:      }
 34:    }
 35:  
 36:    case class Clause(literals: Set[Literal]) {
 37:      def value(assignment: Assignment): Option[Boolean] =
 38:        literals.map(assignment(_)).reduceLeft((x,y) => (x,y) match {
 39:          case (Some(false), Some(false)) => Some(false)
 40:          case (Some(true), _) | (_, Some(true)) => Some(true)
 41:          case _ => None
 42:        })
 43:      def apply(assignment: Assignment) =
 44:        literals.filter(lit => ! assignment.isDefinedAt(lit))
 45:      override def toString =
 46:        literals.mkString("{", ", ", "}")
 47:    }
 48:  
 49:    class DPLL(numOfVariables: Int, clauses: Set[Clause]) {
 50:      var debug = false
 51:      val assignment = new Assignment(numOfVariables)
 52:      val trail: ArrayStack[Literal] = new ArrayStack[Literal]()
 53:      def assign(lit: Literal, value: Boolean) {
 54:        assignment(lit) = value
 55:        trail.push(lit)
 56:      }
 57:      def value(lit: Literal) =
 58:        assignment(lit)
 59:      def undo(size: Int) {
 60:        while (trail.size > size) {
 61:          val lit = trail.pop
 62:          assignment.clear(lit)
 63:        }
 64:      }
 65:      def unitPropagation {
 66:        var change = true
 67:        while (change) {
 68:          change = false
 69:          for (clause <- clauses) {
 70:            if (clause.value(assignment) != Some(true)) {
 71:              val c = clause(assignment)
 72:              if (c.size == 1) {
 73:                val lit = c.head
 74:                if (debug) println("Propagate : " + lit + " -> " + true)
 75:                assign(lit, true)
 76:                change = true
 77:              }
 78:            }
 79:          }
 80:        }
 81:      }
 82:      def select: Literal =
 83:        (1 to numOfVariables).find(! assignment.isDefinedAt(_)).get
 84:      def decide(lit: Literal, value: Boolean): Boolean = {
 85:        if (debug) println("Decide : " + lit + " -> " + value)
 86:        assign(lit, value)
 87:        solve
 88:      }
 89:      def solve: Boolean = {
 90:        val n = trail.size
 91:        unitPropagation
 92:        if (clauses.exists(_.value(assignment) == Some(false))) {
 93:          if (debug) println("Backtrack")
 94:          undo(n)
 95:          false
 96:          true
 97:        } else if (clauses.forall(_.value(assignment) == Some(true))) {
 98:          true
 99:        } else {
100:          val lit = select
101:          if (decide(lit, false) || decide(lit, true)) {
102:            true
103:          } else {
104:            undo(n)
105:            false
106:          }
107:        }
108:      }
109:    }
110:    object DPLL {
111:      def parse(file: String): (Int, Set[Clause]) = {
112:        val clauses = scala.collection.mutable.Set[Clause]()
113:        var numOfVariables = 0
114:        for (line <- scala.io.Source.fromFile(file).getLines()) {
115:          if (line.startsWith("p ")) {
116:            numOfVariables = line.split("\\s+")(2).toInt
117:          } else if (line.startsWith("c ") || line.matches("\\s*")) {
118:          } else if (line.matches("(-?\\d+\\s+)*0")) {
119:            var lits = Set[Literal]()
120:            for (lit <- line.split("\\s+").map(_.toInt).dropRight(1))
121:              lits += lit
122:            clauses += Clause(lits)
123:          } else {
124:            println("CNF file format error in " + file + ": " + line)
125:          }
126:        }
127:        (numOfVariables, clauses.toSet)
128:      }
129:      def main(args: Array[String]) {
130:        val (numOfVariables, clauses) = parse(args(0))
131:        val solver = new DPLL(numOfVariables, clauses)
132:        if (args.size >= 2 && args(1) == "-d") {
133:          println(clauses)
134:          solver.debug = true
135:        }
136:        if (solver.solve) {
137:            println("SAT")
138:            for (i <- 1 to numOfVariables) {
139:              solver.value(i) match {
140:                case None | Some(true) => print(i)
141:                case Some(false) => print(-i)
142:              }
143:              print(" ")
144:            }
145:            println("0")
146:        } else {
147:          println("UNSAT")
148:        }
149:      }
150:    }
151:  }

2.3.1 実行時間の計測

実行時間を計測してみると,10倍以上の速度向上結果となった.

$ scalac prop-sat13.scala
$ time scala sat.DPLL sudoku-20.cnf
SAT
-1 -2 -3 4 5 6 7 8 -9 ... 0

real    0m6.421s
user    0m7.908s
sys     0m0.136s

再びプロファイラー機能を利用し, 上記プログラムの遅い部分を計測してみる.

$ JAVA_OPTS="-Xrunhprof:cpu=samples" scala sat.DPLL sudoku-20.cnf
CPU SAMPLES BEGIN (total = 660) Sun Jul  3 22:17:20 2011
rank   self  accum   count trace method
   1 41.82% 41.82%     276 300195 scala.collection.immutable.ListSet$Node.elem
   2 28.94% 70.76%     191 300196 scala.collection.immutable.ListSet$Node.elem
   3  2.73% 73.48%      18 300049 java.lang.ClassLoader.defineClass1
   4  0.91% 74.39%       6 300130 scala.collection.immutable.Set$Set2.iterator
   5  0.91% 75.30%       6 300151 java.util.Arrays.copyOfRange
   6  0.76% 76.06%       5 300055 java.util.zip.Inflater.inflateBytes
   7  0.76% 76.82%       5 300173 scala.collection.immutable.HashSet$HashTrieSet.updated0
   8  0.76% 77.58%       5 300211 scala.collection.TraversableLike$$anonfun$map$1.apply
   9  0.61% 78.18%       4 300146 java.lang.Integer.hashCode
   .....
CPU SAMPLES END

今度は Set 処理に時間がかかっているようだ. これは節集合の処理に関する部分と思われる.

節集合を Set[Clause] ではなく List[Clause] にすると prop-sat14.scala となる. 実行時間を計測してみると以下のようになり,さらに約3倍の速度向上が実現できた.

$ scalac prop-sat14.scala
$ time scala sat.DPLL sudoku-20.cnf
SAT
-1 -2 -3 4 5 6 7 8 -9 ... 0

real    0m1.545s
user    0m2.640s
sys     0m0.096s

3 TODO 近代SATソルバーでの技術

  • CDCL (Conflict Driven Clause Learning)
  • Backjumping
  • Two literal watching
  • Random restart

Date:

Author: 田村直之

Org version 7.8.02 with Emacs version 24

Validate XHTML 1.0