Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ end

### Initialization

In order to start implementing the core parts of our algorithm, we start at the very beginning.
In order to begin implementing the core parts of our algorithm, we start at the very beginning.
There are two main entry points provided by the interface:

- [`initialize_state`](@ref) constructs an entirely new state for the algorithm
Expand All @@ -76,21 +76,21 @@ There are two main entry points provided by the interface:
An example implementation might look like:

```@example Heron
function AlgorithmsInterface.initialize_state(problem::SqrtProblem, algorithm::HeronAlgorithm; kwargs...)
x0 = rand() # random initial guess
stopping_criterion_state = initialize_state(problem, algorithm, algorithm.stopping_criterion)
return HeronState(x0, 0, stopping_criterion_state)
function AlgorithmsInterface.initialize_state(
problem::SqrtProblem, algorithm::HeronAlgorithm,
stopping_criterion_state::StoppingCriterionState;
kwargs...
)
x0 = rand()
Comment on lines +79 to +84

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a small change, but in practice what I've been doing is writing the definition of initialize_state more like this:

function AlgorithmsInterface.initialize_state(
        problem::SqrtProblem, algorithm::HeronAlgorithm,
        stopping_criterion_state::StoppingCriterionState;
        x0 = rand()
    )

Maybe I was being dense, but it took me some time to realize that was a "valid" way to define it and still have control over the parts of the state initialization that I wanted to control, since that allows running solve as:

solve(SqrtProblem(16.0), HeronAlgorithm(StopAfterIteration(10)); x0 = 1.0)

I think seeing the example written that way would have helped me understand how I should define initialize_state.

However, I realized a slightly awkward thing about defining initialize_state in that way and then calling solve as solve(SqrtProblem(16.0), HeronAlgorithm(StopAfterIteration(10)); x0 = 1.0) is that then x0 gets passed to both initialize_state and initialize_state!. Maybe that's not a problem in practice but I found it to be a bit strange (i.e. should initialize_state! use x0 or not?), and also made me confused about the roles of intialize_state vs. initialize_state!. I think the suggestion in this PR of just having initialize_state! reset iteration helps clarify that issue a bit.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment about the kwargs is a really good one, I was looking at this a bit before and it's somewhat difficult to come up with a way to split arbitrary keywords generically, but it could even make sense to just not pass the keyword arguments to initialize_state! anymore if they have already been passed to initialize_state?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about that, indeed maybe it makes sense to not pass the keyword arguments to initialize_state!. That could make the distinction between initialize_state and initialize_state! clearer, since then initialize_state handles external inputs (i.e. initial guesses for the iterate) while initialize_state! just handles resetting the "internal" state, such as the iteration number and stopping criterion state. Calling solve! means that you made the state already, so it seems like anything handled through keyword arguments to initialize_state! could instead be handled by just modifying the state directly (i.e. before calling solve!/initialize_state!).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I disagree here, for me functions of same name should do the same and accept the same keywords. Or to be more precise

initialize_state should create a state from new memory but also
initialize_state! applied to the same memory state should set it to the same situation as well.

iteration = 0
return HeronState(x0, iteration, stopping_criterion_state)
end

function AlgorithmsInterface.initialize_state!(problem::SqrtProblem, algorithm::HeronAlgorithm, state::HeronState; kwargs...)
# reset the state for the algorithm
state.iterate = rand()
state.iteration = 0

# reset the state for the stopping criterion
state = AlgorithmsInterface.initialize_state!(
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
function AlgorithmsInterface.initialize_state!(
problem::SqrtProblem, algorithm::HeronAlgorithm, state::HeronState;
kwargs...
)
state.iteration = 0
return state
end
```
Expand Down Expand Up @@ -175,6 +175,15 @@ Order = [:type, :function]
Private = true
```

### Stopping Criteria

```@autodocs
Modules = [AlgorithmsInterface]
Pages = ["interface/stopping.jl"]
Order = [:type, :function]
Private = true
```

### Next: Stopping criteria

Proceed to the stopping criteria section to add robust halting logic (iteration caps, time limits, tolerance on successive iterates, and combinations) to this square‑root example.
15 changes: 10 additions & 5 deletions docs/src/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,21 @@ mutable struct HeronState <: State
stopping_criterion_state
end

function AlgorithmsInterface.initialize_state(problem::SqrtProblem, algorithm::HeronAlgorithm; kwargs...)
function AlgorithmsInterface.initialize_state(
problem::SqrtProblem, algorithm::HeronAlgorithm,
stopping_criterion_state::StoppingCriterionState;
kwargs...
)
x0 = rand()
stopping_criterion_state = initialize_state(problem, algorithm, algorithm.stopping_criterion)
iteration = 0
return HeronState(x0, 0, stopping_criterion_state)
end

function AlgorithmsInterface.initialize_state!(problem::SqrtProblem, algorithm::HeronAlgorithm, state::HeronState; kwargs...)
state.iterate = rand()
function AlgorithmsInterface.initialize_state!(
problem::SqrtProblem, algorithm::HeronAlgorithm, state::HeronState;
kwargs...
)
state.iteration = 0
initialize_state!(problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state)
return state
end

Expand Down
163 changes: 118 additions & 45 deletions docs/src/stopping_criterion.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,31 +52,31 @@ Here, we delve a bit deeper into the core components of what made our algorithm
### Initialization

The first core component to enable working with stopping criteria is to extend the initialization step to include initializing a [`StoppingCriterionState`](@ref) as well.
This can conveniently be done through the same initialization functions we used for initializing the state:

- [`initialize_state`](@ref) constructs an entirely new stopping state for the algorithm
- [`initialize_state!`](@ref) (in-place) reset of an existing stopping state.
Since some of these may require _stateful_ implementations, we also keep a `stopping_criterion_state` that captures this, and thus needs to be initialized.
By default, the initialization happens automatically and the only thing that is left for us to do is to attach this `stopping_criterion_state` to the `state` in the [`initialize_state`](@ref) function, as we already saw before:

```@example Heron
function AlgorithmsInterface.initialize_state(problem::SqrtProblem, algorithm::HeronAlgorithm; kwargs...)
x0 = rand() # random initial guess
stopping_criterion_state = initialize_state(problem, algorithm, algorithm.stopping_criterion)
function AlgorithmsInterface.initialize_state(
problem::SqrtProblem, algorithm::HeronAlgorithm,
stopping_criterion_state::StoppingCriterionState;
kwargs...
)
x0 = rand()
iteration = 0
return HeronState(x0, 0, stopping_criterion_state)
end

function AlgorithmsInterface.initialize_state!(problem::SqrtProblem, algorithm::HeronAlgorithm, state::HeronState; kwargs...)
# reset the state for the algorithm
state.iterate = rand()
state.iteration = 0

# reset the state for the stopping criterion
state = AlgorithmsInterface.initialize_state!(
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
function AlgorithmsInterface.initialize_state!(
problem::SqrtProblem, algorithm::HeronAlgorithm, state::HeronState;
kwargs...
)
state.iteration = 0
return state
end
```

Note that we do not need to handle any stopping criteria in the [`initialize_state!`](@ref) function, as a separate call to [`AlgorithmsInterface.initialize_stopping_state!`](@ref) is made independently.

### Iteration

During the iteration procedure, as set out by our design principles, we do not have to modify any of the code, and the stopping criteria do not show up:
Expand Down Expand Up @@ -145,9 +145,75 @@ heron_sqrt(2; stopping_criterion = criterion)
## Implementing a new criterion

It is of course possible that we are not satisfied by the stopping criteria that are provided by default.
Suppose we want to stop when successive iterates change by less than `ϵ`, we could achieve this by implementing our own stopping criterion.
In order to do so, we need to define our own structs and implement the required interface.
Again, we split up the data into a _static_ part, the [`StoppingCriterion`](@ref), and a _dynamic_ part, the [`StoppingCriterionState`](@ref).
For example, we might check for convergence by squaring our current `iterate` and seeing if it equals the input value.
In order to do so, we need to define our own struct and implement the required interface.

```@example Heron
struct StopWhenSquared <: StoppingCriterion
tol::Float64 # when do we consider things to be converged
end
```

### Checking for convergence

Then, we need to implement the logic that checks whether an algorithm has finished, which is achieved through [`is_finished`](@ref) and [`is_finished!`](@ref).

```@example Heron
using AlgorithmsInterface: DefaultStoppingCriterionState

function AlgorithmsInterface.is_finished(
problem::SqrtProblem, ::Algorithm, state::State,
stopping_criterion::StopWhenSquared, ::DefaultStoppingCriterionState
)
return state.iteration > 0 && isapprox(state.iterate^2, problem.S; atol = stopping_criterion.tol)
end
```

Note that we automatically obtain a `DefaultStoppingCriterionState` as the final argument, in which we have to store the iteration at which convergence is reached.
As this is a mutating operation that alters the `stopping_criterion_state`, we ensure that it is called exactly once per iteration, while the non-mutating version is simply used to inspect the current status.

```@example Heron
function AlgorithmsInterface.is_finished!(
problem::SqrtProblem, ::Algorithm, state::State,
stopping_criterion::StopWhenSquared, stopping_criterion_state::DefaultStoppingCriterionState
)
if state.iteration > 0 && isapprox(state.iterate^2, problem.S; atol = criterion.tol)
stopping_criterion_state.at_iteration = state.iteration
return true
else
return false
end
end
```

### Reason and convergence reporting

Finally, we need to implement [`get_reason`](@ref) and [`indicates_convergence`](@ref).
These helper functions are required to interact with the [logging system](@ref sec_logging), to distinguish between states that are considered ongoing, stopped and converged, or stopped without convergence.

```@example Heron
function AlgorithmsInterface.get_reason(stopping_criterion::StopWhenSquared, stopping_criterion_state::DefaultStoppingCriterionState)
stopping_criterion_state.at_iteration >= 0 || return nothing
return "The algorithm reached a square root after $(stopping_criterion_state.at_iteration) iterations up to a tolerance of $(stopping_criterion.tol)."
end

AlgorithmsInterface.indicates_convergence(::StopWhenSquared, ::DefaultStoppingCriterionState) = true
```

### Convergence in action

Then we are finally ready to test out our new stopping criteria.

```@example Heron
criterion = StopWhenSquared(1e-8)
heron_sqrt(16.0; stopping_criterion = criterion)
```

### Initialization

Now suppose we want to stop when successive iterates change by less than `ϵ`.
This can be achieved by introducing a new stopping criterion again, but now we have to retain the previous `iterate` in order to have something to compare against.
Similar to the algorithm `State`, we split up the data into a _static_ part, the [`StoppingCriterion`](@ref), and a _dynamic_ part, the [`StoppingCriterionState`](@ref).

```@example Heron
struct StopWhenStable <: StoppingCriterion
Expand All @@ -161,40 +227,54 @@ mutable struct StopWhenStableState <: StoppingCriterionState
end
```

Note that our mutable state holds both the `previous_iterate`, which we need to compare to,
as well as the iteration at which the condition was satisfied.
Note that our mutable state holds both the `previous_iterate`, which we need to compare to, as well as the iteration at which the condition was satisfied.
This is not strictly necessary, but can be convenient to have a persistent indication that convergence was reached.

### Initialization

In order to support these _stateful_ criteria, again an initialization phase is needed.
The relevant functions are now:

- [`AlgorithmsInterface.initialize_stopping_state`](@ref)
- [`AlgorithmsInterface.initialize_stopping_state!`](@ref)

This could be implemented as follows:

```@example Heron
function AlgorithmsInterface.initialize_state(::Problem, ::Algorithm, c::StopWhenStable; kwargs...)
function AlgorithmsInterface.initialize_stopping_state(
::Problem, ::Algorithm,
stopping_criterion::StopWhenStable;
kwargs...
)
return StopWhenStableState(NaN, -1, NaN)
end

function AlgorithmsInterface.initialize_state!(
::Problem, ::Algorithm, stop_when::StopWhenStable, st::StopWhenStableState;
function AlgorithmsInterface.initialize_stopping_state!(
::Problem, ::Algorithm, ::State,
stopping_criterion::StopWhenStable,
stopping_criterion_state::StopWhenStableState;
kwargs...
)
st.previous_iterate = NaN
st.at_iteration = -1
st.delta = NaN
return st
)
stopping_criterion_state.previous_iterate = NaN
stopping_criterion_state.at_iteration = -1
stopping_criterion_state.delta = NaN
return stopping_criterion_state
end
```

### Checking for convergence
!!! note

Then, we need to implement the logic that checks whether an algorithm has finished, which is achieved through [`is_finished`](@ref) and [`is_finished!`](@ref).
Here, the mutating version alters the `stopping_criterion_state`, and should therefore be called exactly once per iteration, while the non-mutating version is simply used to inspect the current status.
While for this simple case this does not matter, note that there is a subtle detail associated to the initialization order of the `State` and `StoppingCriterionState` respectively.
For the first initialization, [`AlgorithmsInterface.initialize_stopping_state`](@ref) is called _before_ [`initialize_state`](@ref).
This is required since the `State` encapsulates the `StoppingCriterionState`.
On the other hand, during the solver, the [`AlgorithmsInterface.initialize_stopping_state!`](@ref) is called _before_ [`initialize_state`](@ref).
This can be important for example to ensure that the initialization time of the state is taken into account for the stopping criteria.

The remainder of the implementation follows straightforwardly, where we again take care to only mutate the `stopping_criterion_state` in the mutating `is_finished!` implementation.

```@example Heron
function AlgorithmsInterface.is_finished!(
::Problem, ::Algorithm, state::State, c::StopWhenStable, st::StopWhenStableState
)
)

k = state.iteration
if k == 0
st.previous_iterate = state.iterate
Expand All @@ -213,21 +293,14 @@ end

function AlgorithmsInterface.is_finished(
::Problem, ::Algorithm, state::State, c::StopWhenStable, st::StopWhenStableState
)
)
k = state.iteration
k == 0 && return false

Δ = abs(state.iterate - st.previous_iterate)
return Δ < c.tol
end
```

### Reason and convergence reporting

Finally, we need to implement [`get_reason`](@ref) and [`indicates_convergence`](@ref).
These helper functions are required to interact with the [logging system](@ref sec_logging), to distinguish between states that are considered ongoing, stopped and converged, or stopped without convergence.

```@example Heron
function AlgorithmsInterface.get_reason(c::StopWhenStable, st::StopWhenStableState)
(st.at_iteration >= 0 && st.delta < c.tol) || return nothing
return "The algorithm reached an approximate stable point after $(st.at_iteration) iterations; the change $(st.delta) is less than $(c.tol)."
Expand All @@ -238,14 +311,14 @@ AlgorithmsInterface.indicates_convergence(c::StopWhenStable, st::StopWhenStableS

### Convergence in action

Then we are finally ready to test out our new stopping criterion.
Again, we can inspect our work:

```@example Heron
criterion = StopWhenStable(1e-8)
heron_sqrt(16.0; stopping_criterion = criterion)
```

Note that our work payed off, as we can still compose this stopping criterion with other criteria as well:
Note that our work to ensure the correct interface payed off, as we can still compose this stopping criterion with other criteria as well:

```@example Heron
criterion = StopWhenStable(1e-8) | StopAfterIteration(5)
Expand All @@ -258,7 +331,7 @@ Implementing a criterion usually means defining:

1. A subtype of [`StoppingCriterion`](@ref).
2. A state subtype of [`StoppingCriterionState`](@ref) capturing dynamic fields.
3. `initialize_state` and `initialize_state!` for setup/reset.
3. `initialize_stopping_state` and `initialize_stopping_state!` for setup/reset.
4. `is_finished!` (mutating) and optionally `is_finished` (non‑mutating) variants.
5. `get_reason` (return `nothing` or a string) for user feedback.
6. `indicates_convergence(::YourCriterion)` to mark if meeting it implies convergence.
Expand Down
1 change: 1 addition & 0 deletions src/AlgorithmsInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using ScopedValues
include("interface/algorithm.jl")
include("interface/problem.jl")
include("interface/state.jl")
include("interface/stopping.jl")
include("interface/interface.jl")

include("stopping_criterion.jl")
Expand Down
Loading
Loading