Skip to content

Commit

Permalink
do not prosecute invalid rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-luecke committed Oct 16, 2024
1 parent b83ef37 commit a48e181
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 106 deletions.
126 changes: 35 additions & 91 deletions src/main/scala/elevate/heuristic_search/heuristics/Exhaustive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,96 +6,68 @@ import elevate.heuristic_search.util.{Solution, hashProgram}

import scala.collection.immutable.Queue

// TODO: rename to breadth first search
class Exhaustive[P] extends Heuristic[P] {

private val allowInvalid: Boolean = false

// todo cleanup
// breadth first
def start(panel: HeuristicPanel[P], initialSolution: Solution[P], depth: Int, samples: Int): ExplorationResult[P] = {

println("depth: " + depth)

var counter = 0

var solution = initialSolution
val solutionValue = panel.f(solution)

// craete path
// val path = new Path(solution.expression, solutionValue, null, null, 0)

var queue = Queue.empty[(Int, Solution[P])]
var counter: Int = 0
var solution: Solution[P] = initialSolution
var solutionValue: Option[Double] = panel.f(solution)
var queue: Queue[(Int, Solution[P])] = Queue.empty[(Int, Solution[P])]
queue = queue.enqueue(0, solution)

var i = 0
while (!queue.isEmpty) {
i = i + 1

// println("i: " + i)
// println("queue: " + queue)

// get element from queue
val current = queue.dequeue

// println("current: " + current)

// update current path element
// path.setCurrent(current._1._2)
// todo reach this from start (step by step)
// path.add(current._1._2.program, current._1._2.strategy, current._1._2.value)

// // start at initial node
// var down = path.initial
//
// println("\n")
// println(" --------- go down ---------- ")
// // go down step by step until reaching current program
//
// while (hashProgram(current._1._2.solution.expression) != hashProgram(down.solution.expression)) {
// println("down: " + hashProgram(down.solution.expression))
// println("current: " + hashProgram(current._1._2.solution.expression))
// down = down.successor
// // tmp.program.hashCode() == tmp.successor.program.hashCode()){
// // go one step down
// path.add(down.solution, down.value)
// }
// println(" --------- finished ---------- ")
// println("\n")


// update queue
queue = current._2

// get neighborhood
// execute all elements in the neighborhood and queue them
val Ns = panel.N(current._1._2)

Ns.foreach(ne => {
// path.writePathToDot("/home/jo/development/rise-lang/shine/exploration/dot/mv.dot")
// eval function value

// change this value!

// todo make this configurable option!
val layer = current._1._1 + 1

if (counter < samples) {

// execute
val fne = panel.f(ne)
counter += 1

// check result, minimum and queue accordingly
fne match {
// allow to enqueue invalid results
case None => // don't enqueue

// we don't know the predecessors
// no information if current rewrite sequence is invalid?
// invalid: no performance value
case None =>

if (layer < depth) {
queue = queue.enqueue((layer, ne))
// only enqueue if we allow invalid solutions
if (allowInvalid) {
if (layer < depth) {
queue = queue.enqueue((layer, ne))
}
}

// valid: performance value
case Some(candidateValue) => {

// check if a new minimum was found
val update = solutionValue match {
case Some(sValue) =>
candidateValue < sValue match {
case true => (ne, Some(candidateValue))
case false => (solution, Some(sValue))
}
case None => (ne, Some(candidateValue))
}
case Some(_) => {
solution = update._1
solutionValue = update._2

// enqueue if we have not reached the exploration limit
if (layer < depth) {
queue = queue.enqueue((layer, ne))
}

}
}

Expand All @@ -106,38 +78,10 @@ class Exhaustive[P] extends Heuristic[P] {
None
)
}

// val fne = None

// add path element
// path.add(ne, fne)

// add path element and solution to queue

// revert path
// path.add(Solution(current._1._1.expression, current._1._1.strategies ++ Seq(elevate.core.strategies.basic.revert)), current._1._2.value)
})


// println("\n")
// println(" --------- go up ---------- ")
// var up = current._1._2
// while (up.predecessor != null) {
// up = up.predecessor
// path.add(Solution(up.solution.expression, up.solution.strategies ++ Seq(elevate.core.strategies.basic.revert)), up.value)
// }
// println(" --------- finished ---------- ")
// println("\n")
// current._1._2.predecessor match {
// case null => // do nothing
// case _ =>
// go back to parent
// path.add(current._1._2.predecessor.program, elevate.core.strategies.basic.revert, current._1._2.predecessor.value)
// }

}

// last?
// return best found solution
ExplorationResult(
solution,
solutionValue,
Expand Down
88 changes: 80 additions & 8 deletions src/main/scala/elevate/heuristic_search/heuristics/MCTS.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package elevate.heuristic_search.heuristics

import elevate.heuristic_search._
import elevate.heuristic_search.util.{Solution}
import elevate.heuristic_search.util.Solution

import scala.collection.mutable
import scala.util.Random


Expand Down Expand Up @@ -89,27 +90,37 @@ class MCTS[P] extends Heuristic[P] {

// 3. Rollout
// we start the rollout at the current node
var rollout = node.solution
var rollout: (Solution[P], Option[Double]) = (node.solution, None)

var isTerminal: Boolean = false
while (rollout.solutionSteps.count(step => step.strategy != elevate.core.strategies.basic.id[P]) < depth && !isTerminal) {
val actions = panel.N(rollout)
while (!isTerminal && rollout._1.solutionSteps.count(step => step.strategy != elevate.core.strategies.basic.id[P]) < depth) {

val actions = panel.N(rollout._1)
if (actions.nonEmpty) {
rollout = actions(Random.nextInt(actions.size))

// try to consider only valid ones
// rollout performance should be the minimum that was seen during rollout
rollout = choose_valid_solution_randomly(panel = panel, actions = actions, rollout._2)

rollout._1.solutionSteps.foreach(step => println(s"""[${step.strategy}, ${step.location}]"""))

// check if we have a dead end for this rollout
if (rollout._1 == null) {
isTerminal = true
}

} else {
isTerminal = true
}
}


// 4. Backpropagation
val value: Option[Double] = panel.f(rollout)
counter += 1
while (node != null) {
node.visits += 1

// this can be biased by the ranges
val win: Double = value match {
val win: Double = rollout._2 match {
case Some(value) => 1 / value
case None => 0
}
Expand All @@ -118,6 +129,67 @@ class MCTS[P] extends Heuristic[P] {
}
}

def choose_valid_solution_randomly(panel: HeuristicPanel[P], actions: Seq[Solution[P]], minimum: Option[Double]): (Solution[P], Option[Double]) = {

def findSolution(minimum: Option[Double], attempts: Set[Solution[P]]): (Solution[P], Option[Double]) = {
val remainingActions = actions.filterNot(attempts.contains)

if (remainingActions.isEmpty) {
(null.asInstanceOf[Solution[P]], None) // No valid solution found
} else {
val candidate: Solution[P] = remainingActions(Random.nextInt(remainingActions.size))

// get performance of
panel.f(candidate) match {
case Some(value) =>
minimum match {
case None =>
(candidate, Some(value)) // Valid solution found

case Some(minimum_value) =>
value <= minimum_value match {
case true => (candidate, Some(value))
case false => (candidate, Some(minimum_value))
}
}
case None => findSolution(minimum, attempts + candidate) // Add to attempts and recurse
}
}
}

findSolution(minimum = minimum, attempts = Set.empty[Solution[P]])
}

//
// def choose_valid_solution_randomly(panel: HeuristicPanel[P], actions: Seq[Solution[P]]): (Solution[P], Option[Double]) = {
//
// var found_valid: Boolean = false
// var rollout: (Solution[P], Option[Double]) = (null.asInstanceOf[Solution[P]], None)
// val attempts: mutable.Set[Solution[P]] = scala.collection.mutable.Set.empty[Solution[P]]
//
// while (!found_valid) {
//
// val remainingActions = actions.filterNot(attempts.contains)
//
// rollout = remainingActions.isEmpty match {
// case true => rollout
// case false =>
//
// val candidate: Solution[P] = remainingActions(Random.nextInt(remainingActions.size))
// attempts.add(candidate)
//
// rollout = panel.f(solution = candidate) match {
// case Some(value) =>
// found_valid = true
// (candidate, Some(value))
// case None => (candidate, None)
// }
// rollout
// }
// }
// rollout
// }

// return dummy optimized program
ExplorationResult(
solution = initialSolution,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,37 @@ class RandomGraph[P] extends Heuristic[P] {
depthCounter = depth
solution

// chose valid solution randomly from neighborhood
// choose valid solution randomly from neighborhood
case _ =>

// get next element
solution = Ns.apply(random.nextInt(Ns.size))
solutionValue = panel.f(solution)
var foundValid = false

var attempts = scala.collection.mutable.Set.empty[Solution[P]]

while (!foundValid) {

// filter out already visited candidates
val candidates = Ns.filter(sol => !attempts.contains(sol))

// if no candidate left start from root
if (candidates.isEmpty) {
depthCounter = depth
foundValid = true
} else {
// get next element
sampleCounter += 1
solution = candidates.apply(random.nextInt(candidates.size))
solutionValue = panel.f(solution)

if (solutionValue.equals(None)) {
// add attempt
attempts += solution

} else {
foundValid = true
}
}
}
solution
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class TabuSearchPlain[P] extends Heuristic[P] {


def start(panel: HeuristicPanel[P], initialSolution: Solution[P], depth: Int, samples: Int): ExplorationResult[P] = {

var solution = initialSolution

val random = scala.util.Random
Expand All @@ -22,7 +23,7 @@ class TabuSearchPlain[P] extends Heuristic[P] {
// // add empty foreach layer
// Range(0, depth + 1).foreach(layer => visited.addOne(layer, mutable.HashSet.empty[Seq[Int]]))

var tabuList = scala.collection.mutable.Queue.empty[Seq[RewriteIdentifier[P]]]
var tabuList = scala.collection.mutable.Queue.empty[Seq[RewriteIdentifier[P]]] // can we collapse a sequence of rewrite identifiers?
val tabuListSize = 100

var counter = 0
Expand Down Expand Up @@ -90,7 +91,8 @@ class TabuSearchPlain[P] extends Heuristic[P] {
}
})

// add parent to tabu list
// add parent to tabu list
// allow to go back
tabuList = tabuList.addOne(solution.rewrites())

// update solution
Expand All @@ -104,7 +106,7 @@ class TabuSearchPlain[P] extends Heuristic[P] {
// beset
NsFSorted.take(NsFiltered.size / 4).foreach(elem => tabuList = tabuList.addOne(elem._1.rewrites()))

// worset
// worset?
// NsFSorted.takeRight(NsFiltered.size / 2).foreach(elem => tabuList = tabuList.addOne(elem._1.rewrites()))

// add all elements to tab lust
Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/elevate/heuristic_search/util/Solution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ case class Solution[P](
solutionSteps.size
}

override def toString: String = {
solutionSteps.map(step => s"${step.strategy}, ${step.location}").mkString("\n")
}

}

case class RewriteIdentifier[P](
Expand Down

0 comments on commit a48e181

Please sign in to comment.