Problem Intro

You have an untrained neural net that spits out a gravity scalar potential \( U \) in response to a 3D Cartesian position

The goal is to figure out the gravitational acceleration from this scalar potential

Turns out the gravitational acceleration \( \mathbf{a}(\mathbf{x}) \) is obtained as the negative gradient of the potential:

\[ \mathbf{a}(\mathbf{x}) = -\nabla U(\mathbf{x}) \]


At some point we want to train the neural net, so inevitably we'll need a dataset of 3D Cartesian positions mapped to their respective acceleration values

But what does the loss function to train the NN look like? For this we take a look at the loss function described in this paper: Physics-informed neural networks for gravity field modeling of small bodies

The loss in question is here:

\[ \mathcal{L}(\theta) = \underbrace{\frac{1}{N_f} \sum_{i=1}^{N_f} \left\| \mathbf{a}_{i}^{\text{true}} - \left(- \nabla U(\mathbf{x}_{i})\right) \right\|^2}_{\text{acceleration error}} + \underbrace{\frac{1}{N_f} \sum_{i=1}^{N_f} \left\| \nabla^2 U(\mathbf{x}_i) \right\|^2}_{\text{Laplacian penalty}} + \underbrace{\frac{1}{N_f} \sum_{i=1}^{N_f} \left\| \nabla \times \nabla U(\mathbf{x}_i) \right\|^2}_{\text{curl penalty}} \]

The first term probably makes sense given it is your basic MSE with the true acceleration values. As for the rest ...

The paper's contribution is the observation that, since gravity is a conservative force, the loss function can be augmented with terms that enforce additional dynamical properties

The scalar potential learned by the network must also obey these additional physics constraints:

  1. \( \nabla^2 U = 0 \)
  2. \( \nabla \times \mathbf{a} = 0 \)

So the cool bit is that by including these terms, we push the learned model to be more in line with the physics of the problem: the gravity field being modelled behaves like a conservative force, along with everything that fact implies

That's all good and dandy but how do we put this in code?

We'll be looking at implementing this in Jax

The MSE part is straightforward, but coding the additional constraints is not, as will be shown. This blog post is about figuring out how to encode the additional constraints related to the acceleration field in Jax


Brief tangent

Hang on, why bother with a NN spitting out a scalar potential when it could spit out the three acceleration components directly?

"Most obvious is the fact that the network is now trained with the knowledge that there is a relationship between the accelerations it produces via automatic differentiation and the more fundamental scalar potential. Second, because the network is learning a representation of the potential rather than the accelerations, all three acceleration vector components of the training data are now being used to constrain a single scalar function. This is a much more efficient regression for the network, learning a single potential rather than being forced to learn three separate acceleration features"

Something about this claim (i.e. that learning one scalar is more efficient than learning three components) should get your spidey sense tingling. I do have some thoughts on this, as I think it is a weak claim, and I promise to address it in a future blog post


Calculating the Laplacian of the Scalar Potential and the Curl of the Acceleration Field

There are no Jax primitives to calculate these directly. To get going, a refresher in vector calculus is in order

Gradient

The gradient takes a scalar-valued (differentiable) function (of several variables!) and produces a vector-valued function

A more precise way of expressing the domain and range of the function \( f \) and the gradient of the function \( \nabla f \) is:

\[ f: \mathbb{R}^n \rightarrow \mathbb{R}, \quad \nabla f: \mathbb{R}^n \rightarrow \mathbb{R}^n \]

And now how the gradient is computed at a point \( p = (x, y, z) \) (Note: we're using just 3 variables, but that is not a constraint in general):

\[ \nabla f(p) = \frac{\partial f(p)}{\partial x}\mathbf{i} + \frac{\partial f(p)}{\partial y}\mathbf{j} + \frac{\partial f(p)}{\partial z}\mathbf{k} \]
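In Jax this is exactly what jax.grad gives us. A quick sketch on a toy scalar function (my own example, not the network from the paper):

import jax
import jax.numpy as jnp

# toy scalar function of three variables, standing in for a potential
f = lambda p: p[0] ** 2 + p[1] * p[2]

grad_f = jax.grad(f)                        # gradient: R^3 -> R^3
print(grad_f(jnp.array([1.0, 2.0, 3.0])))   # [2. 3. 2.]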


Jacobian: Generalization of the Gradient

While the input to the gradient is a scalar function, at some point it makes sense to ask whether there is an equivalent when the input is a vector-valued function

For a function \( \mathbf{F} : \mathbb{R}^n \rightarrow \mathbb{R}^m \), the Jacobian is an \( m \times n \) matrix whose \( (i, j) \)-th entry is \( \partial F_i / \partial x_j \)

To make things more concrete, let's say we have a function \( \mathbf{F} : \mathbb{R}^3 \rightarrow \mathbb{R}^3 \) that takes in \( \mathbf{p} = ( x, y, z) \in \mathbb{R}^{3} \) as input and produces \( (f_{0}(\mathbf{p}), f_{1}(\mathbf{p}), f_{2}(\mathbf{p})) \in \mathbb{R}^{3} \)

The Jacobian is:

\[ J_{\mathbf{F}}(\mathbf{p}) = \begin{bmatrix} \frac{\partial f_0}{\partial x} & \frac{\partial f_0}{\partial y} & \frac{\partial f_0}{\partial z} \\ \frac{\partial f_1}{\partial x} & \frac{\partial f_1}{\partial y} & \frac{\partial f_1}{\partial z} \\ \frac{\partial f_2}{\partial x} & \frac{\partial f_2}{\partial y} & \frac{\partial f_2}{\partial z} \end{bmatrix} \]
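jax.jacobian computes this matrix for us. A small sketch with a toy vector field (my own example, purely illustrative):

import jax
import jax.numpy as jnp

# toy vector field F : R^3 -> R^3
def F(p):
    x, y, z = p
    return jnp.array([x * y, y * z, x * z])

J_F = jax.jacobian(F)
print(J_F(jnp.array([1.0, 2.0, 3.0])))
# [[2. 1. 0.]
#  [0. 3. 2.]
#  [3. 0. 1.]]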


Divergence

While the gradient maps a scalar function to a vector field (a vector field is pretty much a vector-valued function [there's a bit more to this]), and the Jacobian maps a vector function to a matrix field (pretty much a matrix-valued function), the divergence maps a vector field to a scalar-valued function.

Computationally, that means vector in, scalar out

For a vector field \( \mathbf{F} : \mathbb{R}^n \rightarrow \mathbb{R}^n \), the divergence is defined as the dot product of the del operator with the vector field:

\[ \nabla \cdot \mathbf{F} = \sum_{i=0}^{n-1} \frac{\partial F_i}{\partial x_i} \]

More concretely, if \( \mathbf{F}(x, y, z) = (F_0, F_1, F_2) \), then:

\[ \nabla \cdot \mathbf{F} = \frac{\partial F_0}{\partial x} + \frac{\partial F_1}{\partial y} + \frac{\partial F_2}{\partial z} \]
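Jax has no divergence primitive, but since the divergence is just the sum of the diagonal of the Jacobian, a sketch using the same toy field as above could look like this:

import jax
import jax.numpy as jnp

# divergence as the trace of the Jacobian of a toy vector field
def F(p):
    x, y, z = p
    return jnp.array([x * y, y * z, x * z])

div_F = lambda p: jnp.trace(jax.jacobian(F)(p))
print(div_F(jnp.array([1.0, 2.0, 3.0])))   # y + z + x = 6.0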


Curl

The curl is along the same lines as the divergence in that it operates on a vector field. Something interesting to note about this operator is that it doesn't generalize beyond 3D space; the mechanical reason for this is evident once you see how the curl of a vector field is calculated

For a vector field \( \mathbf{F}(x, y, z) = (F_0, F_1, F_2) \), the curl is defined as the cross product of the del operator with \( \mathbf{F} \):

\[ \nabla \times \mathbf{F} = \begin{vmatrix} \mathbf{i} & \mathbf{j} & \mathbf{k} \\ \frac{\partial}{\partial x} & \frac{\partial}{\partial y} & \frac{\partial}{\partial z} \\ F_0 & F_1 & F_2 \end{vmatrix} \]

Expanded out, this becomes:

\[ \nabla \times \mathbf{F} = \left( \frac{\partial F_2}{\partial y} - \frac{\partial F_1}{\partial z}\right) \mathbf{i} + \left( \frac{\partial F_0}{\partial z} - \frac{\partial F_2}{\partial x}\right)\mathbf{j} + \left(\frac{\partial F_1}{\partial x} - \frac{\partial F_0}{\partial y} \right)\mathbf{k} \]

Curiously, the result is a new vector field representing the axis and intensity of the local rotation of the input vector field.
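There is no curl primitive in Jax either, but all the entries we need sit in the off-diagonal of the Jacobian. A sketch on a toy field that rotates around the z-axis (my own example):

import jax
import jax.numpy as jnp

# toy vector field that rotates around the z-axis
def F(p):
    x, y, z = p
    return jnp.array([y, -x, 0.0])

def curl_F(p):
    J = jax.jacobian(F)(p)
    return jnp.array([J[2, 1] - J[1, 2],
                      J[0, 2] - J[2, 0],
                      J[1, 0] - J[0, 1]])

print(curl_F(jnp.array([1.0, 2.0, 3.0])))   # [ 0.  0. -2.]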

Laplacian

(Parroting the Wikipedia article on the same) The Laplacian is a differential operator that can be thought of as the divergence of the gradient of a scalar-valued function.

For a scalar-valued function \( f : \mathbb{R}^n \rightarrow \mathbb{R} \), the Laplacian is defined as:

\[ \Delta f = \nabla \cdot \nabla f = \nabla^2 f = \sum_{i=0}^{n-1} \frac{\partial^2 f}{\partial x_i^2} \]


Hessian

The mother lode when it comes to computing everything we need, as will become evident shortly.

The Hessian is a square matrix of second-order partial derivatives of a scalar-valued function

To put it another way, the Hessian of a function \( f \) is the Jacobian matrix of the gradient of \( f \), that is: \( \mathbf{H}(f(\mathbf{x})) = \mathbf{J}( \nabla f(\mathbf{x})) \).

For a scalar function \( f : \mathbb{R}^n \rightarrow \mathbb{R} \), the Hessian matrix is:

\[ H_f(x) = \left[ \frac{\partial^2 f}{\partial x_i \partial x_j} \right]_{i,j=1}^{n} \]

So if \( f \) is a function of three variables i.e. \( f(x, y, z) \), then:

\[ H_f(x, y, z) = \begin{bmatrix} \frac{\partial^2 f}{\partial x^2} & \frac{\partial^2 f}{\partial x \partial y} & \frac{\partial^2 f}{\partial x \partial z} \\ \frac{\partial^2 f}{\partial y \partial x} & \frac{\partial^2 f}{\partial y^2} & \frac{\partial^2 f}{\partial y \partial z} \\ \frac{\partial^2 f}{\partial z \partial x} & \frac{\partial^2 f}{\partial z \partial y} & \frac{\partial^2 f}{\partial z^2} \end{bmatrix} \]

Jax primitives to calculate these:
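For the gradient, the Jacobian, and the Hessian, Jax ships ready-made transformations: jax.grad, jax.jacobian, and jax.hessian. A minimal sketch on a toy scalar function (my own example, not the paper's network):

import jax
import jax.numpy as jnp

# toy scalar function standing in for the learned potential
f = lambda p: jnp.sum(p ** 2)

p = jnp.array([1.0, 2.0, 3.0])

print(jax.grad(f)(p))                # gradient, shape (3,): [2. 4. 6.]
print(jax.jacobian(jax.grad(f))(p))  # Jacobian of the gradient = Hessian, 3x3
print(jax.hessian(f)(p))             # same 3x3 matrix, computed directly

The divergence, the curl, and the Laplacian, on the other hand, have no dedicated primitives, which is exactly what the next section works around.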


Closer look at operators that have no direct Jax routes

When we take a closer look at the Laplacian and keep staring at it for a bit, we see that it is just the sum of the diagonal entries of the Hessian, i.e. its trace.

\[ H_f(x, y, z) = \begin{bmatrix} \colorbox{yellow}{$\displaystyle \frac{\partial^2 f}{\partial x^2}$} & \frac{\partial^2 f}{\partial x \partial y} & \frac{\partial^2 f}{\partial x \partial z} \\ \frac{\partial^2 f}{\partial y \partial x} & \colorbox{yellow}{$\displaystyle \frac{\partial^2 f}{\partial y^2}$} & \frac{\partial^2 f}{\partial y \partial z} \\ \frac{\partial^2 f}{\partial z \partial x} & \frac{\partial^2 f}{\partial z \partial y} & \colorbox{yellow}{$\displaystyle \frac{\partial^2 f}{\partial z^2}$} \end{bmatrix} \]

So what? Well, given that we have no native Jax operation for the Laplacian, this observation makes our life easier: all we need to do to calculate the Laplacian of the scalar potential (which is the same thing as the divergence of the acceleration field, up to sign) for our loss function is:


import jax
import jax.numpy as jnp

# model : R^3 -> R is the network that returns the scalar potential at position x
hessian = jax.hessian(model)(x)
# or you could build the Hessian as the Jacobian of the gradient,
# i.e. jax.jacobian(jax.grad(model))(x)

laplacian = jnp.trace(hessian)

What about the curl of the acceleration field? For this particular case, since we get to start from the scalar potential, it's the Hessian to the rescue again

What do I mean? Let's look at the formula for the curl of the acceleration

\[ \nabla \times \mathbf{F} = \begin{vmatrix} \mathbf{i} & \mathbf{j} & \mathbf{k} \\ \frac{\partial}{\partial x} & \frac{\partial}{\partial y} & \frac{\partial}{\partial z} \\ A_x & A_y & A_z \end{vmatrix} \]

Remember \( A_x, A_y, A_z \) are the three components of the acceleration, which (dropping the minus sign in \( \mathbf{a} = -\nabla U \), since it doesn't affect whether the curl vanishes) is nothing but

\[ A_x \mathbf{i} + A_y \mathbf{j} + A_z \mathbf{k} = \frac{\partial s}{\partial x} \mathbf{i} + \frac{\partial s}{\partial y} \mathbf{j} + \frac{\partial s}{\partial z} \mathbf{k} \]

so:

\[ \nabla \times \nabla s = \begin{vmatrix} \mathbf{i} & \mathbf{j} & \mathbf{k} \\ \frac{\partial}{\partial x} & \frac{\partial}{\partial y} & \frac{\partial}{\partial z} \\ \frac{\partial s}{\partial x} & \frac{\partial s}{\partial y} & \frac{\partial s}{\partial z} \end{vmatrix} \]

\[ \nabla \times \nabla s = \left( \frac{\partial^2 s}{\partial y \partial z} - \frac{\partial^2 s}{\partial z \partial y}\right) \mathbf{i} + \left( \frac{\partial^2 s}{\partial z \partial x} - \frac{\partial^2 s}{\partial x \partial z}\right)\mathbf{j} + \left(\frac{\partial^2 s}{\partial x \partial y} - \frac{\partial^2 s}{\partial y \partial x} \right)\mathbf{k} \]

You, dear reader, have probably seen enough, but please be so kind as to allow me to drive home the point:

\[ H_f(x, y, z) = \begin{bmatrix} \frac{\partial^2 f}{\partial x^2} & \colorbox{yellow}{$\displaystyle \frac{\partial^2 f}{\partial x \partial y} $} & \colorbox{pink}{$\displaystyle \frac{\partial^2 f}{\partial x \partial z} $} \\ \colorbox{yellow}{$\displaystyle \frac{\partial^2 f}{\partial y \partial x} $} & \frac{\partial^2 f}{\partial y^2} & \colorbox{green}{$\displaystyle \frac{\partial^2 f}{\partial y \partial z} $} \\ \colorbox{pink}{$\displaystyle \frac{\partial^2 f}{\partial z \partial x} $} & \colorbox{green}{$\displaystyle \frac{\partial^2 f}{\partial z \partial y} $} & \frac{\partial^2 f}{\partial z^2} \end{bmatrix} \]

Note: Shouldn't this always be equal to zero? Turns out not always (equality of the mixed partials requires sufficiently smooth functions), but it should be driven to zero as a constraint when searching the space of all potentials


import jax
import jax.numpy as jnp

# model : R^3 -> R is the network that returns the scalar potential at position x
hessian = jax.hessian(model)(x)
# or equivalently jax.jacobian(jax.grad(model))(x)

# curl components from the off-diagonal entries of the Hessian
curl_x = hessian[2, 1] - hessian[1, 2]
curl_y = hessian[0, 2] - hessian[2, 0]
curl_z = hessian[1, 0] - hessian[0, 1]
curl_of_s = jnp.stack((curl_x, curl_y, curl_z))

So all of this is cool, but does it really work? Let's put this to the test on an analytical scalar potential field defined by a point mass

Which is \( U(\mathbf{r}) = \frac{1}{\|\mathbf{r}\|} = \frac{1}{\sqrt{x^2 + y^2 + z^2}} \) (ignoring the negative sign and any constants)

Calculating the divergence and curl of the acceleration field derived from this should give us values close to 0 for pretty much all coordinates away from the origin


import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# point-mass potential (ignoring the sign and constants); multiplying it by a
# large constant like 1000 would still give curl and divergence close to zero
f = lambda x: 1. / jnp.linalg.norm(x)

acceleration_field = jax.grad(f)

def curl(acceleration_field):
    jacobian = jax.jacobian(acceleration_field)

    def curl_at(x):
        j = jacobian(x)
        return jnp.array([j[2, 1] - j[1, 2],
                          j[0, 2] - j[2, 0],
                          j[1, 0] - j[0, 1]])
    return curl_at

def divergence(acceleration_field):
    jacobian = jax.jacobian(acceleration_field)

    def divergence_at(x):
        return jnp.trace(jacobian(x))
    return divergence_at

v_curl = jax.vmap(curl(acceleration_field))
v_divergence = jax.vmap(divergence(acceleration_field))

points = jnp.linspace(-2.0, 2.0, 10)
X, Y, Z = jnp.meshgrid(points, points, points)
coords = jnp.stack([X.reshape(-1), Y.reshape(-1), Z.reshape(-1)], axis=-1)

# we want to avoid anything too close to the origin, where the potential blows up
filtered_points = coords[jnp.linalg.norm(coords, axis=-1) > 0.1]

curls = jnp.linalg.norm(v_curl(filtered_points), axis=-1)
divs = jnp.abs(v_divergence(filtered_points))

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.hist(curls, bins=30, color='skyblue')
plt.title("Curl Magnitude Distribution")
plt.xlabel("curl magnitude")
plt.ylabel("frequency")

plt.subplot(1, 2, 2)
plt.hist(divs, bins=30, color='salmon')
plt.title("Divergence Magnitude Distribution")
plt.xlabel("divergence magnitude")
plt.ylabel("frequency")

plt.tight_layout()
plt.savefig("curl_and_div.png")

Seems like the magnitudes of the curl and divergence of the acceleration field associated with our analytical potential function are very, very close to zero, providing strong evidence that our computation is right

Conclusion

Given that we now know how to compute the curl and divergence in Jax, encoding the constraints for the conservative force should be a lot more straightforward
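To close the loop, here is a minimal sketch of how the full loss from the paper could be assembled, assuming a model(params, x) that maps a 3-vector to the scalar potential (the function names and the plain functional style are my own, not the paper's):

import jax
import jax.numpy as jnp

def loss_fn(params, positions, true_accelerations, model):
    # Sketch of the loss: acceleration MSE + Laplacian penalty + curl penalty.
    # Assumes model(params, x) maps a 3-vector x to the scalar potential U(x).
    potential = lambda x: model(params, x)

    def per_point(x, a_true):
        grad_u = jax.grad(potential)(x)        # gradient of the potential
        hess_u = jax.hessian(potential)(x)     # 3x3 matrix of second derivatives

        acc_err = jnp.sum((a_true - (-grad_u)) ** 2)       # acceleration error
        laplacian = jnp.trace(hess_u)                      # trace of the Hessian
        curl = jnp.array([hess_u[2, 1] - hess_u[1, 2],     # curl of the gradient
                          hess_u[0, 2] - hess_u[2, 0],
                          hess_u[1, 0] - hess_u[0, 1]])
        return acc_err + laplacian ** 2 + jnp.sum(curl ** 2)

    return jnp.mean(jax.vmap(per_point)(positions, true_accelerations))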