diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 5a90faa..39b322b 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -171,19 +171,32 @@ function _forward_eval( end elseif node.index == 3 # :* # Node `k` is not scalar, so we do matrix multiplication + # (or scalar `*` matrix scaling when one operand is scalar). if f.sizes.ndims[k] != 0 @assert N == 2 idx1 = first(children_indices) idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - v1 = _view_matrix(f.forward_storage, f.sizes, ix1) - v2 = _view_matrix(f.forward_storage, f.sizes, ix2) - out = _view_matrix(f.forward_storage, f.sizes, k) - LinearAlgebra.mul!(out, v1, v2) + out = _view_linear(f.forward_storage, f.sizes, k) + if f.sizes.ndims[ix1] == 0 + s = _getscalar(f.forward_storage, f.sizes, ix1) + v = _view_linear(f.forward_storage, f.sizes, ix2) + out .= s .* v + elseif f.sizes.ndims[ix2] == 0 + v = _view_linear(f.forward_storage, f.sizes, ix1) + s = _getscalar(f.forward_storage, f.sizes, ix2) + out .= v .* s + else + v1 = _view_matrix(f.forward_storage, f.sizes, ix1) + v2 = _view_matrix(f.forward_storage, f.sizes, ix2) + out_m = _view_matrix(f.forward_storage, f.sizes, k) + LinearAlgebra.mul!(out_m, v1, v2) + end # We deliberately don't write v1/v2 into partials_storage - # here: the matmul reverse branch reads forward_storage - # directly, so those writes were dead. + # here: the matmul (or scalar-scaling) reverse branch + # reads forward_storage directly, so those writes were + # dead. 
# Node `k` is scalar else tmp_prod = one(T) @@ -391,14 +404,29 @@ function _forward_eval( children_indices = SparseArrays.nzrange(f.adj, k) N = length(children_indices) if node.index == 1 # :+ (broadcasted) - for j in _eachindex(f.sizes, k) - tmp_sum = zero(T) - for c_idx in children_indices - ix = children_arr[c_idx] - @j f.partials_storage[ix] = one(T) - tmp_sum += @j f.forward_storage[ix] + # Broadcast-aware sum: scalar children contribute their + # single value to every output slot. + out = _view_linear(f.forward_storage, f.sizes, k) + fill!(out, zero(T)) + for c_idx in children_indices + ix = children_arr[c_idx] + if f.sizes.ndims[ix] == 0 + s = _getscalar(f.forward_storage, f.sizes, ix) + out .+= s + _setscalar!( + f.partials_storage, + one(T), + f.sizes, + ix, + ) + else + v = _view_linear(f.forward_storage, f.sizes, ix) + out .+= v + fill!( + _view_linear(f.partials_storage, f.sizes, ix), + one(T), + ) end - @j f.forward_storage[k] = tmp_sum end elseif node.index == 2 # :- (broadcasted) @assert N == 2 @@ -406,31 +434,88 @@ function _forward_eval( @inbounds ix1 = children_arr[child1] @inbounds ix2 = children_arr[child1+1] out = _view_linear(f.forward_storage, f.sizes, k) - v1 = _view_linear(f.forward_storage, f.sizes, ix1) - v2 = _view_linear(f.forward_storage, f.sizes, ix2) - out .= v1 .- v2 - fill!(_view_linear(f.partials_storage, f.sizes, ix1), one(T)) - fill!(_view_linear(f.partials_storage, f.sizes, ix2), -one(T)) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + s1 = _getscalar(f.forward_storage, f.sizes, ix1) + v2 = _view_linear(f.forward_storage, f.sizes, ix2) + out .= s1 .- v2 + _setscalar!(f.partials_storage, one(T), f.sizes, ix1) + fill!( + _view_linear(f.partials_storage, f.sizes, ix2), + -one(T), + ) + elseif ndims1 != 0 && ndims2 == 0 + v1 = _view_linear(f.forward_storage, f.sizes, ix1) + s2 = _getscalar(f.forward_storage, f.sizes, ix2) + out .= v1 .- s2 + fill!( + _view_linear(f.partials_storage, 
f.sizes, ix1), + one(T), + ) + _setscalar!(f.partials_storage, -one(T), f.sizes, ix2) + else + v1 = _view_linear(f.forward_storage, f.sizes, ix1) + v2 = _view_linear(f.forward_storage, f.sizes, ix2) + out .= v1 .- v2 + fill!( + _view_linear(f.partials_storage, f.sizes, ix1), + one(T), + ) + fill!( + _view_linear(f.partials_storage, f.sizes, ix2), + -one(T), + ) + end elseif node.index == 3 # :* (broadcasted) - # Node `k` is not scalar, so we do matrix multiplication + # Node `k` is not scalar, so we do element-wise multiply + # (with scalar-broadcast support: when one operand is + # scalar, broadcast it across the matrix output). if f.sizes.ndims[k] != 0 @assert N == 2 idx1 = first(children_indices) idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - v1 = zeros(_size(f.sizes, ix1)...) - v2 = zeros(_size(f.sizes, ix2)...) - for j in _eachindex(f.sizes, ix1) - v1[j] = @j f.forward_storage[ix1] - @j f.partials_storage[ix2] = v1[j] - end - for j in _eachindex(f.sizes, ix2) - v2[j] = @j f.forward_storage[ix2] - @j f.partials_storage[ix1] = v2[j] - end - for j in _eachindex(f.sizes, k) - @j f.forward_storage[k] = v1[j] * v2[j] + out = _view_linear(f.forward_storage, f.sizes, k) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + s = _getscalar(f.forward_storage, f.sizes, ix1) + v = _view_linear(f.forward_storage, f.sizes, ix2) + out .= s .* v + # Per-element partial w.r.t. the matrix child is + # the scalar; the scalar child's reverse is handled + # by the broadcasted-`:*` reverse branch below + # (sum of `rev_parent .* v`). 
+ fill!( + _view_linear(f.partials_storage, f.sizes, ix2), + s, + ) + elseif ndims1 != 0 && ndims2 == 0 + v = _view_linear(f.forward_storage, f.sizes, ix1) + s = _getscalar(f.forward_storage, f.sizes, ix2) + out .= v .* s + fill!( + _view_linear(f.partials_storage, f.sizes, ix1), + s, + ) + else + # Both children are arrays of the same shape — + # original element-wise path. + v1 = zeros(_size(f.sizes, ix1)...) + v2 = zeros(_size(f.sizes, ix2)...) + for j in _eachindex(f.sizes, ix1) + v1[j] = @j f.forward_storage[ix1] + @j f.partials_storage[ix2] = v1[j] + end + for j in _eachindex(f.sizes, ix2) + v2[j] = @j f.forward_storage[ix2] + @j f.partials_storage[ix1] = v2[j] + end + for j in _eachindex(f.sizes, k) + @j f.forward_storage[k] = v1[j] * v2[j] + end end # Node `k` is scalar else @@ -620,23 +705,54 @@ function _reverse_eval( op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] if op == :* if f.sizes.ndims[k] != 0 - # Matrix multiplication: rev_v1 = rev_parent * v2', - # rev_v2 = v1' * rev_parent. Both v1 and v2 are read - # straight from forward_storage (the matmul forward - # branch deliberately doesn't snapshot them into - # partials_storage), and the reverse views are written - # in place. + # Matmul (or `scalar * matrix` scaling): rev_v1 = + # rev_parent * v2', rev_v2 = v1' * rev_parent. With + # a scalar operand, the result is `s .* M`, so + # rev[s] = sum(rev_parent .* M) and rev[M] = + # rev_parent .* s. Both v1 and v2 are read straight + # from forward_storage. 
idx1 = first(children_indices) idx2 = last(children_indices) ix1 = children_arr[idx1] ix2 = children_arr[idx2] - v1 = _view_matrix(f.forward_storage, f.sizes, ix1) - v2 = _view_matrix(f.forward_storage, f.sizes, ix2) - rev_parent = _view_matrix(f.reverse_storage, f.sizes, k) - rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1) - rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2) - LinearAlgebra.mul!(rev_v1, rev_parent, v2') - LinearAlgebra.mul!(rev_v2, v1', rev_parent) + rev_parent = + _view_linear(f.reverse_storage, f.sizes, k) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + v2 = _view_linear(f.forward_storage, f.sizes, ix2) + s1 = _getscalar(f.forward_storage, f.sizes, ix1) + rev_v2 = _view_linear(f.reverse_storage, f.sizes, ix2) + rev_v2 .= rev_parent .* s1 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v2), + f.sizes, + ix1, + ) + elseif ndims1 != 0 && ndims2 == 0 + v1 = _view_linear(f.forward_storage, f.sizes, ix1) + s2 = _getscalar(f.forward_storage, f.sizes, ix2) + rev_v1 = _view_linear(f.reverse_storage, f.sizes, ix1) + rev_v1 .= rev_parent .* s2 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v1), + f.sizes, + ix2, + ) + else + v1 = _view_matrix(f.forward_storage, f.sizes, ix1) + v2 = _view_matrix(f.forward_storage, f.sizes, ix2) + rev_parent_m = + _view_matrix(f.reverse_storage, f.sizes, k) + rev_v1 = + _view_matrix(f.reverse_storage, f.sizes, ix1) + rev_v2 = + _view_matrix(f.reverse_storage, f.sizes, ix2) + LinearAlgebra.mul!(rev_v1, rev_parent_m, v2') + LinearAlgebra.mul!(rev_v2, v1', rev_parent_m) + end continue end elseif op == :vect @@ -832,13 +948,82 @@ function _reverse_eval( elseif node.type == NODE_CALL_MULTIVARIATE_BROADCASTED if node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS) op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] + # Broadcasted +/- with at least one scalar child: the + # scalar's reverse is the (signed) sum of the parent's + # 
adjoint over the broadcast positions. Handle both scalar + # and matrix children here so the generic + # diagonal-partial path below doesn't trip its + # `_size(k) == _size(ix)` assertion. + if (op == :+ || op == :-) && any( + c -> f.sizes.ndims[children_arr[c]] == 0, + children_indices, + ) && f.sizes.ndims[k] != 0 + Tr = eltype(f.reverse_storage) + rev_parent = + _view_linear(f.reverse_storage, f.sizes, k) + for c_idx in children_indices + ix = children_arr[c_idx] + # `:-` flips the sign for the second operand, mirroring + # the partial we wrote in the forward pass. + partial_sign = + (op == :- && c_idx != first(children_indices)) ? + -one(Tr) : one(Tr) + if f.sizes.ndims[ix] == 0 + _setscalar!( + f.reverse_storage, + partial_sign * sum(rev_parent), + f.sizes, + ix, + ) + else + rev_child = + _view_linear(f.reverse_storage, f.sizes, ix) + rev_child .= partial_sign .* rev_parent + end + end + continue + end if op == :* if f.sizes.ndims[k] != 0 - # Node `k` is not scalar, so we do matrix multiplication or broadcasted multiplication idx1 = first(children_indices) idx2 = last(children_indices) ix1 = children_arr[idx1] ix2 = children_arr[idx2] + rev_parent = + _view_linear(f.reverse_storage, f.sizes, k) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + v2 = + _view_linear(f.forward_storage, f.sizes, ix2) + s1 = _getscalar(f.forward_storage, f.sizes, ix1) + rev_v2 = + _view_linear(f.reverse_storage, f.sizes, ix2) + rev_v2 .= rev_parent .* s1 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v2), + f.sizes, + ix1, + ) + continue + elseif ndims1 != 0 && ndims2 == 0 + v1 = + _view_linear(f.forward_storage, f.sizes, ix1) + s2 = _getscalar(f.forward_storage, f.sizes, ix2) + rev_v1 = + _view_linear(f.reverse_storage, f.sizes, ix1) + rev_v1 .= rev_parent .* s2 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v1), + f.sizes, + ix2, + ) + continue + end + # Both children are arrays of the same 
shape — + # original element-wise path. v1 = zeros(_size(f.sizes, ix1)...) v2 = zeros(_size(f.sizes, ix2)...) for j in _eachindex(f.sizes, ix1) @@ -847,18 +1032,18 @@ function _reverse_eval( for j in _eachindex(f.sizes, ix2) v2[j] = @j f.forward_storage[ix2] end - rev_parent = zeros(_size(f.sizes, k)...) + rev_parent_arr = zeros(_size(f.sizes, k)...) for j in _eachindex(f.sizes, k) - rev_parent[j] = @j f.reverse_storage[k] + rev_parent_arr[j] = @j f.reverse_storage[k] end rev_v1 = zeros(_size(f.sizes, ix1)...) rev_v2 = zeros(_size(f.sizes, ix2)...) for j in _eachindex(f.sizes, ix1) - rev_v1[j] = rev_parent[j] * v2[j] + rev_v1[j] = rev_parent_arr[j] * v2[j] @j f.reverse_storage[ix1] = rev_v1[j] end for j in _eachindex(f.sizes, ix2) - rev_v2[j] = rev_parent[j] * v1[j] + rev_v2[j] = rev_parent_arr[j] * v1[j] @j f.reverse_storage[ix2] = rev_v2[j] end continue diff --git a/src/sizes.jl b/src/sizes.jl index 7fc3c90..0af3c0d 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -309,8 +309,17 @@ function _infer_sizes( return !iszero(sizes.ndims[children_arr[i]]) end if !isnothing(first_matrix) - if sizes.ndims[children_arr[first(children_indices)]] == 0 - _add_size!(sizes, k, (1, 1)) + first_is_scalar = + sizes.ndims[children_arr[first(children_indices)]] == 0 + last_is_scalar = + sizes.ndims[children_arr[last(children_indices)]] == 0 + if first_is_scalar || last_is_scalar + # `scalar * matrix` (or `matrix * scalar`) is + # element-wise scaling, not matmul: result inherits + # the matrix's shape. 
+ ix_mat = + children_arr[children_indices[first_matrix]] + _copy_size!(sizes, k, ix_mat) continue else _add_size!( @@ -357,47 +366,29 @@ function _infer_sizes( continue end op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] - if op == :+ || op == :- - # Broadcasted +/- preserves shape - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :^ - # Broadcasted ^ with scalar exponent preserves base shape - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :* - # TODO assert compatible sizes and all ndims should be 0 or 2 + if op == :+ || op == :- || op == :* + # Broadcasted +/-/* takes the largest child's shape (for + # scalar+matrix that's the matrix; for matrix+matrix we + # currently assume they match and pick the first). first_matrix = findfirst(children_indices) do i return !iszero(sizes.ndims[children_arr[i]]) end - if !isnothing(first_matrix) - if sizes.ndims[children_arr[first(children_indices)]] == 0 - _add_size!(sizes, k, (1, 1)) - continue - else - if sizes.ndims[children_arr[first(children_indices)]] == - 1 - nb_cols = 1 - else - nb_cols = _size( - sizes, - children_arr[first(children_indices)], - 1, - ) - end - _add_size!( - sizes, - k, - ( - _size( - sizes, - children_arr[first(children_indices)], - 1, - ), - nb_cols, - ), - ) - continue - end + if isnothing(first_matrix) + _copy_size!( + sizes, + k, + children_arr[first(children_indices)], + ) + else + _copy_size!( + sizes, + k, + children_arr[children_indices[first_matrix]], + ) end + elseif op == :^ + # Broadcasted ^ with scalar exponent preserves base shape + _copy_size!(sizes, k, children_arr[first(children_indices)]) end elseif node.type == NODE_CALL_UNIVARIATE if !( diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index 1a888e7..f56f7b8 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -558,9 +558,9 @@ function test_objective_broadcasted_product() evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4]) 
MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes - @test sizes.ndims == [0, 2, 1, 0, 0, 1, 0, 0] + @test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0] @test sizes.size_offset == [0, 2, 1, 0, 0, 0, 0, 0] - @test sizes.size == [2, 2, 2, 1] + @test sizes.size == [2, 2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11] x1 = 1.0 x2 = 2.0 diff --git a/test/JuMP.jl b/test/JuMP.jl index 9c44143..4923efc 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -7,6 +7,8 @@ using ArrayDiff import LinearAlgebra import MathOptInterface as MOI +include("Transformer.jl") + function runtests() for name in names(@__MODULE__; all = true) if startswith("$(name)", "test_") @@ -387,6 +389,63 @@ function test_moi_function() return end +# Plug JuMP variable matrices into the Transformer's `MLP` building block +# (`gelu(x * c_fc) * c_proj`) and confirm the forward+reverse pass runs +# end-to-end through the ArrayDiff evaluator. `gelu` exercises every +# scalar-broadcast pattern that ArrayDiff supports for `MatrixExpr`: +# `Number * matrix` scaling, `Number .* matrix`, and `Number .+ matrix`. +# We finite-difference the analytic gradient as a sanity check. 
+function test_transformer_mlp_gradient() + d_emb, d_hidden, seq = 2, 3, 2 + model = Model() + @variable( + model, + c_fc[1:d_emb, 1:d_hidden], + container = ArrayDiff.ArrayOfVariables, + ) + @variable( + model, + c_proj[1:d_hidden, 1:d_emb], + container = ArrayDiff.ArrayOfVariables, + ) + mlp = MLP(c_fc, c_proj) + x = rand(seq, d_emb) + loss = sum(mlp(x) .^ 2) + mode = ArrayDiff.Mode() + ad = ArrayDiff.model(mode) + MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss)) + evaluator = MOI.Nonlinear.Evaluator( + ad, + mode, + JuMP.index.(JuMP.all_variables(model)), + ) + MOI.initialize(evaluator, [:Grad]) + nvar = JuMP.num_variables(model) + @test nvar == 2 * d_emb * d_hidden + x_pt = randn(nvar) + val = MOI.eval_objective(evaluator, x_pt) + @test isfinite(val) + @test val >= 0 + g = zeros(nvar) + MOI.eval_objective_gradient(evaluator, g, x_pt) + @test all(isfinite, g) + @test !all(iszero, g) + # Central finite differences on the AD-built objective. + h = 1e-6 + g_fd = zeros(nvar) + for i in 1:nvar + xp = copy(x_pt) + xp[i] += h + xm = copy(x_pt) + xm[i] -= h + g_fd[i] = + (MOI.eval_objective(evaluator, xp) - + MOI.eval_objective(evaluator, xm)) / (2h) + end + @test isapprox(g, g_fd; rtol = 1e-4) + return +end + +end # module TestJuMP.runtests() diff --git a/test/Transformer.jl b/test/Transformer.jl new file mode 100644 index 0000000..783c8f2 --- /dev/null +++ b/test/Transformer.jl @@ -0,0 +1,158 @@ +# A minimal GPT-style Transformer implementation in pure Julia +# Inspired from https://github.com/karpathy/nanogpt + +using Random +using LinearAlgebra + +# Helper functions +function gelu(x) + return 0.5 * x .* (1 .+ tanh.(sqrt(2 / π) .* (x .+ 0.044715 .* x.^3))) +end + +# LayerNorm +struct LayerNorm{V} + γ::V + β::V + ϵ::Float64 +end + +function LayerNorm(dim::Int; ϵ=1e-5) + # We could use `ones(dim)` and then + # do `γ'` but then we'll need to implement + # `adjoint` for `VectNode` + γ = ones(1, dim) + β = zeros(1, dim) + return LayerNorm(γ, β, ϵ) +end + +function 
(ln::LayerNorm)(x) + d = size(x, 2) + μ = sum(x, dims=2) / d + σ2 = sum((x .- μ).^2, dims=2) / d + x̂ = (x .- μ) ./ sqrt.(σ2 .+ ln.ϵ) + return ln.γ .* x̂ .+ ln.β +end + +# Causal Self-Attention (single head) +struct CausalSelfAttention{M} + wq::M + wk::M + wv::M +end + +function CausalSelfAttention(d_emb::Int, d_head::Int) + wq = randn(d_emb, d_head) * 0.02 + wk = randn(d_emb, d_head) * 0.02 + wv = randn(d_emb, d_head) * 0.02 + return CausalSelfAttention(wq, wk, wv) +end + +function (attn::CausalSelfAttention)(x) + # x: (seq, d_emb) + q = x * attn.wq + k = x * attn.wk + v = x * attn.wv + + seq = size(x, 1) + d_head = size(attn.wq, 2) + + attn_scores = (q * k') / sqrt(d_head) + # Causal mask + mask = [i < j ? -Inf : Inf for i in 1:seq, j in 1:seq] + attn_scores = min.(attn_scores, mask) + attn_weights = softmax(attn_scores, dims=2) + return attn_weights * v +end + +# Multi-Head Attention +struct MultiHead{M} + heads::Vector{CausalSelfAttention{M}} + wo::M +end + +function MultiHead(d_emb::Int, n_head::Int) + head_dim = div(d_emb, n_head) + heads = [CausalSelfAttention(d_emb, head_dim) for _ in 1:n_head] + wo = randn(n_head * head_dim, d_emb) * 0.02 + return MultiHead(heads, wo) +end + +function (mha::MultiHead)(x) + outs = [head(x) for head in mha.heads] + out = reduce(hcat, outs) # (seq, n_head*head_dim) + return out * mha.wo +end + +# MLP (Feedforward network) +struct MLP{M} + c_fc::M + c_proj::M +end + +function MLP(d_emb::Int, d_hidden::Int) + c_fc = randn(d_emb, d_hidden) * 0.02 + c_proj = randn(d_hidden, d_emb) * 0.02 + return MLP(c_fc, c_proj) +end + +function (mlp::MLP)(x) + return gelu(x * mlp.c_fc) * mlp.c_proj +end + +# Transformer Block +struct Block{V,M} + ln1::LayerNorm{V} + attn::MultiHead + ln2::LayerNorm{V} + mlp::MLP{M} +end + +function Block(d_emb::Int, n_head::Int, n_hidden::Int) + ln1 = LayerNorm(d_emb) + attn = MultiHead(d_emb, n_head) + ln2 = LayerNorm(d_emb) + mlp = MLP(d_emb, n_hidden) + return Block(ln1, attn, ln2, mlp) +end + +function 
(block::Block)(x) + x = x .+ block.attn(block.ln1(x)) + x = x .+ block.mlp(block.ln2(x)) + return x +end + +# The full Transformer +struct Transformer{V,M} + wte::M # token embedding + wpe::M # position embedding + blocks::Vector{Block{V,M}} + ln_f::LayerNorm{V} + n_voc::Int + d_emb::Int +end + +function Transformer(; n_voc::Int, n_ctx::Int, n_layer::Int, n_head::Int, d_emb::Int, d_ff::Int) + wte = randn(n_voc, d_emb) * 0.02 + wpe = randn(n_ctx, d_emb) * 0.02 + blocks = [Block(d_emb, n_head, d_ff) for _ in 1:n_layer] + ln_f = LayerNorm(d_emb) + return Transformer(wte, wpe, blocks, ln_f, n_voc, d_emb) +end + +function (m::Transformer)(idx) + x = m.wte[idx, :] .+ m.wpe[eachindex(idx), :] + for block in m.blocks + x = block(x) + end + x = m.ln_f(x) + # logits: (seq, n_voc) + logits = x * m.wte' + return logits +end + +# Softmax helper +function softmax(x; dims=1) + x_max = maximum(x, dims=dims) + ex = exp.(x .- x_max) + return ex ./ sum(ex, dims=dims) +end