8000 Feat!: new parameter API by fjebaker · Pull Request #63 · fjebaker/SpectralFitting.jl · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Feat!: new parameter API #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 1, 2023
4 changes: 4 additions & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ using DocStringExtensions
abstract type AbstractMission end
struct NoMission <: AbstractMission end

abstract type AbstractStatistic end

# unitful units
include("units.jl")
SpectralUnits.@reexport using .SpectralUnits

include("print-utilities.jl")

include("fitparam.jl")
include("param-cache.jl")
include("abstract-models.jl")

include("ccall-wrapper.jl")
Expand All @@ -66,6 +69,7 @@ include("model-data-io.jl")
include("fitting/result.jl")
include("fitting/cache.jl")
include("fitting/problem.jl")
include("fitting/binding.jl")
include("fitting/multi-cache.jl")
include("fitting/methods.jl")
include("fitting/statistics.jl")
Expand Down
107 changes: 59 additions & 48 deletions src/abstract-models.jl
< F438 td class="blob-code blob-code-addition js-file-line"> invokemodel!(output, energy, model, p0)
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ export AbstractSpectralModel,
invokemodel!,
objective_cache_count,
modelparameters,
freeparameters,
frozenparameters,
updateparameters!
updateparameters!,
make_parameter_cache

"""
abstract type AbstractSpectralModel{T,K}
Expand Down Expand Up @@ -57,7 +56,7 @@ abstract type AbstractSpectralModelKind end

Additive models are effectively the sources of photons, and are the principle building blocks
of composite models. Every additive model has a normalisation parameter which re-scales the
flux by a constant factor `K`.
output by a constant factor `K`.

!!! note
Defining custom additive models requires special care. See [Defining new models](@ref).
Expand All @@ -68,15 +67,15 @@ struct Additive <: AbstractSpectralModelKind end
Multiplicative()

Multiplicative models act on [`Additive`](@ref) models, by element-wise
multiplying the flux in each energy bin of the additive model by a different factor.
multiplying the output in each energy bin of the additive model by a different factor.
"""
struct Multiplicative <: AbstractSpectralModelKind end
"""
Convolutional <: AbstractSpectralModelKind
Convolutional()

Convolutional models act on the flux generated by [`Additive`](@ref) models, similar to
[`Multiplicative`](@ref) models, however may convolve kernels through the flux also.
Convolutional models act on the output generated by [`Additive`](@ref) models, similar to
[`Multiplicative`](@ref) models, however may convolve kernels through the output also.
"""
struct Convolutional <: AbstractSpectralModelKind end

Expand Down Expand Up @@ -125,10 +124,10 @@ ConstructionBase.constructorof(::Type{M}) where {M<:AbstractSpectralModel} = M
# never to be called directly
# favour `invokemodel!` instead
"""
SpectralFitting.invoke!(flux, energy, M::Type{<:AbstractSpectralModel}, params...)
SpectralFitting.invoke!(output, energy, M::Type{<:AbstractSpectralModel}, params...)

Used to define the behaviour of models. Should calculate flux of the model and write in-place
into `flux`.
Used to define the behaviour of models. Should calculate output of the model and write in-place
into `output`.

!!! warning
This function should not be called directly. Use [`invokemodel`](@ref) instead.
Expand All @@ -144,15 +143,15 @@ end
```
would have the arguments passed to `invoke!` as
```julia
function SpectralFitting.invoke!(flux, energy, ::Type{<:MyModel}, p1, p2, p3, ...)
function SpectralFitting.invoke!(output, energy, ::Type{<:MyModel}, p1, p2, p3, ...)
# ...
end
```

The only exception to this are [`Additive`](@ref) models, where the normalisation parameter
`K` is not passed to `invoke!`.
"""
invoke!(flux, energy, M::AbstractSpectralModel) = error("Not defined for $(M).")
invoke!(output, energy, M::AbstractSpectralModel) = error("Not defined for $(M).")

"""
invokemodel(energy, model)
Expand All @@ -169,7 +168,7 @@ any normalisation or post-processing tasks that a specific model kind may requir
Users should always call models using [`invokemodel`](@ref) or [`invokemodel!`](@ref) to ensure
normalisations and closures are accounted for.

`invokemodel` allocates the needed flux arrays based on the element type of `free_params` to allow
`invokemodel` allocates the needed output arrays based on the element type of `free_params` to allow
automatic differentation libraries to calculate parameter gradients.

In-place non-allocating variants are the [`invokemodel!`](@ref) functions.
Expand All @@ -186,29 +185,23 @@ invokemodel(energy, model, p0)
```
"""
function invokemodel(e, m::AbstractSpectralModel)
flux = construct_objective_cache(m, e) |> vec
invokemodel!(flux, e, m)
flux
end
function invokemodel(e, m::AbstractSpectralModel, free_params)
model = remake_with_free(m, free_params)
flux = construct_objective_cache(eltype(free_params), m, e) |> vec
invokemodel!(flux, e, model)
flux
output = construct_objective_cache(m, e) |> vec
invokemodel!(output, e, m)
output
end

"""
invokemodel!(flux, energy, model)
invokemodel!(flux, energy, model, free_params)
invokemodel!(flux, energy, model, free_params, frozen_params)
invokemodel!(output, energy, model)
invokemodel!(output, energy, model, free_params)
invokemodel!(output, energy, model, free_params, frozen_params)

In-place variant of [`invokemodel`](@ref), calculating the flux of an [`AbstractSpectralModel`](@ref)
In-place variant of [`invokemodel`](@ref), calculating the output of an [`AbstractSpectralModel`](@ref)
given by `model`, optionally overriding the free and/or frozen parameter values. These arguments
may be a vector or tuple with element type [`FitParam`](@ref) or `Number`.

The number of fluxes to allocate for a model may change if using any [`CompositeModel`](@ref)
as the `model`. It is generally recommended to use [`objective_cache_count`](@ref) to ensure the correct number
of flux arrays are allocated with [`construct_objective_cache`](@ref) when using composite models.
of output arrays are allocated with [`construct_objective_cache`](@ref) when using composite models.

Single spectral model components should use [`make_flux`](@ref) instead.

Expand All @@ -217,23 +210,26 @@ Single spectral model components should use [`make_flux`](@ref) instead.
```julia
model = XS_PowerLaw()
energy = collect(range(0.1, 20.0, 100))
flux = make_flux(model, energy)
invokemodel!(flux, energy, model)
output = make_flux(model, energy)
invokemodel!(output, energy, model)

p0 = [0.1, 2.0] # change K and a
invokemodel!(flux, energy, model, p0)
```
"""
@inline function invokemodel!(f, e, m::AbstractSpectralModel, free_params)
# update only the free parameters
model = remake_with_free(m, free_params)
invokemodel!(view(f, :, 1), e, model)
end

@inline function invokemodel!(f, e, m::AbstractSpectralModel{<:FitParam})
# need to extract the parameter values
model = remake_with_number_type(m)
invokemodel!(view(f, :, 1), e, model)
end
@inline function invokemodel!(f, e, m::AbstractSpectralModel, cache::ParameterCache)
invokemodel!(f, e, m, cache.parameters)
end
@inline function invokemodel!(f, e, m::AbstractSpectralModel, parameters::AbstractArray)
invokemodel!(view(f, :, 1), e, remake_with_parameters(m, parameters))
end

invokemodel!(
f::AbstractVector,
e::AbstractVector,
Expand Down Expand Up @@ -284,32 +280,21 @@ end

modelparameters(model::AbstractSpectralModel{T}) where {T} =
T[model_parameters_tuple(model)...]
freeparameters(model::AbstractSpectralModel{T}) where {T} =
T[free_parameters_tuple(model)...]
frozenparameters(model::AbstractSpectralModel{T}) where {T} =
T[frozen_parameters_tuple(model)...]

# todo: this function could be cleaned up with some generated hackery
function remake_with_number_type(model::AbstractSpectralModel{P}, T::Type) where {P}
M = typeof(model).name.wrapper
params = modelparameters(model)
params = model_parameters_tuple(model)
new_params = if P <: FitParam
convert.(T, get_value.(params))
else
convert.(T, param)
end
M{T,FreeParameters{free_parameter_symbols(model)}}(new_params...)
M{T}(new_params...)
end
remake_with_number_type(model::AbstractSpectralModel{FitParam{T}}) where {T} =
remake_with_number_type(model, T)

remake_with_free(
model::AbstractSpectralModel{T},
free_params::AbstractVector{T},
) where {T<:Number} = updatefree(model, free_params)
remake_with_free(model::AbstractSpectralModel, free_params) =
updatefree(remake_with_number_type(model, eltype(free_params)), free_params)

"""
updatemodel(model::AbstractSpectralModel; kwargs...)
updatemodel(model::AbstractSpectralModel, patch::NamedTuple)
Expand Down Expand Up @@ -357,3 +342,29 @@ function updateparameters!(model::AbstractSpectralModel{<:FitParam}; params...)
end
model
end

_allocate_free_parameters(model::AbstractSpectralModel) =
filter(isfree, modelparameters(model))

function make_parameter_cache(model::AbstractSpectralModel)
parameters = modelparameters(model)
ParameterCache(parameters)
end

function make_diff_parameter_cache(
model::AbstractSpectralModel;
param_diff_cache_size = nothing,
)
parameters = modelparameters(model)
free_mask = _make_free_mask(parameters)

vals = map(get_value, parameters)
N = isnothing(param_diff_cache_size) ? length(vals) : param_diff_cache_size
diffcache = DiffCache(vals, ForwardDiff.pickchunksize(N))

# embed current parameter values inside of the dual cache
# else all frozens will be zero
get_tmp(diffcache, ForwardDiff.Dual(1.0)) .= vals

ParameterCache(free_mask, diffcache)
end
4 changes: 2 additions & 2 deletions src/ccall-wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ If the callsite is not specified, the user must implement [`_unsafe_ffi_invoke!`
# Examples

```julia
@xspecmodel :C_powerlaw struct XS_PowerLaw{T,F} <: AbstractSpectralModel{T, Additive}
@xspecmodel :C_powerlaw struct XS_PowerLaw{T} <: AbstractSpectralModel{T, Additive}
"Normalisation."
K::T
"Photon index."
Expand All @@ -130,7 +130,7 @@ end

# constructor has default values
function XS_PowerLaw(; K = FitParam(1.0), a = FitParam(1.0))
XS_PowerLaw{typeof(K), SpectralFitting.FreeParameters{(:K, :a)}}(K, a)
XS_PowerLaw{typeof(K)}(K, a)
end
```

Expand Down
61 changes: 33 additions & 28 deletions src/composite-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,37 +101,22 @@ function closurekind(::Type{<:CompositeModel{M1,M2}}) where {M1,M2}
end

# invocation wrappers
function invokemodel(e, m::CompositeModel)
fluxes = construct_objective_cache(m, e)
invokemodel!(fluxes, e, m)
view(fluxes, :, 1)
end

function invokemodel!(f, e, model::CompositeModel)
@assert size(f, 2) == objective_cache_count(model) "Too few flux arrays allocated for this model."
generated_model_call!(f, e, model, model_parameters_tuple(model))
end
function invokemodel!(f, e, model::CompositeModel, free_params, frozen_params)
@assert size(f, 2) == objective_cache_count(model) "Too few flux arrays allocated for this model."
generated_model_call!(f, e, model, free_params, frozen_params)
end
function invokemodel!(f, e, model::CompositeModel, free_params)

function invokemodel!(f, e, model::CompositeModel, parameters::AbstractArray)
@assert size(f, 2) == objective_cache_count(model) "Too few flux arrays allocated for this model."
frozen_params = convert.(eltype(free_params), frozenparameters(model))
invokemodel!(f, e, model, free_params, frozen_params)
generated_model_call!(f, e, model, parameters)
end

function invokemodel(e, m::CompositeModel)
fluxes = construct_objective_cache(m, e)
invokemodel!(fluxes, e, m)
view(fluxes, :, 1)
end
function invokemodel(e, m::CompositeModel, free_params)
if eltype(free_params) <: Number
# for compatability with AD
fluxes = construct_objective_cache(eltype(free_params), m, e)
invokemodel!(fluxes, e, m, free_params)
else
p0 = get_value.(free_params)
fluxes = construct_objective_cache(eltype(p0), m, e)
invokemodel!(fluxes, e, m, p0)
end
view(fluxes, :, 1)
end

# algebra grammar
add_models(_, _, ::M1, ::M2) where {M1,M2} =
Expand Down Expand Up @@ -159,7 +144,7 @@ conv_models(m1::M1, m2::M2) where {M1,M2} =

function Base.show(io::IO, @nospecialize(model::CompositeModel))
expr, infos = _destructure_for_printing(model)
for (symbol, (m, _, _)) in zip(keys(infos), infos)
for (symbol, (m, _)) in zip(keys(infos), infos)
expr =
replace(expr, "$(symbol)" => "$(FunctionGeneration.model_base_name(typeof(m)))")
end
Expand Down Expand Up @@ -221,11 +206,11 @@ function _printinfo(io::IO, model::CompositeModel{M1,M2}) where {M1,M2}

println(io, "Model key and parameters:")
sym_buffer = 5
param_name_offset = sym_buffer + maximum(infos) do (_, syms, _)
param_name_offset = sym_buffer + maximum(infos) do (_, syms)
maximum(length(s) for s in syms)
end
buff = IOBuffer()
for (symbol, (m, param_symbols, states)) in zip(keys(infos), infos)
for (symbol, (m, param_symbols)) in zip(keys(infos), infos)
M = typeof(m)
basename = FunctionGeneration.model_base_name(M)
println(
Expand All @@ -239,7 +224,8 @@ function _printinfo(io::IO, model::CompositeModel{M1,M2}) where {M1,M2}
Crayons.Crayon(reset = true),
)

for (val, s::String, free::Bool) in zip(modelparameters(m), param_symbols, states)
for (val, s::String) in zip(modelparameters(m), param_symbols)
free = !isfrozen(val)
_print_param(buff, free, s, val, param_name_offset, q1, q2, q3, q4)
end
end
Expand All @@ -255,5 +241,24 @@ ConstructionBase.setproperties(::CompositeModel, ::NamedTuple) =
ConstructionBase.constructorof(::Type{<:CompositeModel}) =
throw("Cannot be used with `CompositeModel`.")

function Base.propertynames(model::CompositeModel, private::Bool = false)
all_parameter_symbols(model)
end

function Base.getproperty(model::CompositeModel, symb::Symbol)
lookup = all_parameters_to_named_tuple(model)
lookup[symb]
end

function Base.setproperty!(model::CompositeModel, symb::Symbol, value::FitParam)
set!(getproperty(model, symb), value)
end

function Base.setproperty!(model::CompositeModel, symb::Symbol, x)
error(
"Only `FitParam` may be directly set with another `FitParam`. Use `set_value!` and related API otherwise.",
)
end

# function ConstructionBase.setproperties(m::CompositeModel, patch::NamedTuple)
# end
3D11 21 changes: 0 additions & 21 deletions src/datasets/response.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,6 @@ function normalise_rows!(matrix)
end
end

function build_response_matrix!(
R,
f_chan::Matrix,
n_chan::Matrix,
matrix_rows::Vector,
first_channel,
)
for (i, (F, N)) in enumerate(zip(eachcol(f_chan), eachcol(n_chan)))
M = matrix_rows
index = 1
for (first, len) in zip(F, N)
if len == 0
break
end
first -= first_channel
@views R[first+1:first+len, i] .= M[index:index+len-1]
index += len
end
end
end

function Base.show(
io::IO,
::MIME{Symbol("text/plain")},
Expand Down
Loading
0