
The intersection of many-body physics and deep learning has opened a new frontier: Neural Quantum States (NQS). While traditional methods struggle with high-dimensional frustrated systems, the global attention mechanism of Transformers provides a powerful tool for capturing complex quantum correlations.
In this tutorial, we implement a Variational Monte Carlo (VMC) pipeline in NetKet and JAX to approximate the ground state of the frustrated J1–J2 Heisenberg spin chain. We will:
- Build a custom Transformer-based NQS architecture.
- Optimize the wavefunction using Stochastic Reconfiguration (natural gradient descent).
- Benchmark our results against exact diagonalization and analyze emergent quantum phases.
By the end of this guide, you will have a physically grounded simulation framework that can be scaled toward system sizes where exact diagonalization is no longer practical.
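```python
# Notebook (e.g. Colab) install cells; drop the leading "!" in a terminal.
!pip -q install --upgrade pip
!pip -q install "netket" "flax" "optax" "einops" "tqdm"

import os
# Let JAX allocate GPU memory on demand instead of reserving most of it up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
# Enable float64 before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import matplotlib.pyplot as plt
import netket as nk
import numpy as np
from flax import linen as nn
from tqdm import tqdm

print("JAX devices:", jax.devices())


def make_j1j2_chain(L, J2, total_sz=0.0):
    """Periodic J1-J2 Heisenberg chain built from Pauli matrices, with J1 = 1."""
    J1 = 1.0

    # Colored edges: color 1 = nearest neighbours (J1), color 2 = next-nearest (J2).
    edges = []
    for i in range(L):
        edges.append([i, (i + 1) % L, 1])
        edges.append([i, (i + 2) % L, 2])
    g = nk.graph.Graph(edges=edges)

    # Spin-1/2 Hilbert space restricted to the total S^z = total_sz sector.
    hi = nk.hilbert.Spin(s=0.5, N=L, total_sz=total_sz)

    # Two-site operators: sigma^z sigma^z and sigma^x sigma^x + sigma^y sigma^y.
    sigmaz = np.array([[1, 0], [0, -1]], dtype=np.float64)
    mszsz = np.kron(sigmaz, sigmaz)
    exchange = np.array(
        [[0, 0, 0, 0],
         [0, 0, 2, 0],
         [0, 2, 0, 0],
         [0, 0, 0, 0]], dtype=np.float64
    )

    # The minus sign on the J1 exchange term is the Marshall sign rotation
    # (valid for even L with periodic boundaries).
    bond_ops = [
        (J1 * mszsz).tolist(),
        (J2 * mszsz).tolist(),
        (-J1 * exchange).tolist(),
        (J2 * exchange).tolist(),
    ]
    bond_colors = [1, 2, 1, 2]
    H = nk.operator.GraphOperator(hi, g, bond_ops=bond_ops, bond_ops_colors=bond_colors)
    return g, hi, H
```

We install the required libraries and enable 64-bit precision in JAX for numerically stable optimization. We then define the J1–J2 Hamiltonian on a custom colored graph, where color 1 marks nearest-neighbor (J1) bonds and color 2 marks next-nearest-neighbor (J2) bonds, restrict the Hilbert space to the total S^z = 0 sector, and build a GraphOperator that attaches a 4×4 bond operator to each edge color. The minus sign on the J1 exchange term implements the Marshall sign rotation: a unitary change of basis that leaves the spectrum unchanged and, at least for small J2, removes most of the ground state's sign structure. Because the bond operators are built from Pauli matrices rather than spin-1/2 operators, all energies in this tutorial are four times their values in S·S units.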
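To sanity-check these conventions without NetKet, here is a small NumPy sketch that builds the same chain as a dense matrix for L = 8, with σ^z σ^z on the diagonal and σ^x σ^x + σ^y σ^y swapping antiparallel neighbours with amplitude 2. It checks two facts: the Marshall rotation leaves the full spectrum unchanged, and at the Majumdar-Ghosh point J2 = 0.5 the ground-state energy is exactly -3L/8 in S·S units, i.e. -1.5 L in the Pauli units used here. The helper `dense_j1j2` and its bit-ordering convention are assumptions of this sketch, not part of NetKet.

```python
import numpy as np


def dense_j1j2(L, J2, J1=1.0, marshall=True):
    """Hypothetical helper: dense J1-J2 chain in Pauli units, periodic boundaries.

    Basis state b stores site k in bit (L - 1 - k); a 0 bit means spin up (+1).
    """
    dim = 2 ** L
    H = np.zeros((dim, dim))
    # Each bond: (i, j, coupling, sign of the sigma^x sigma^x + sigma^y sigma^y term).
    nn_sign = -1.0 if marshall else 1.0
    bonds = [(i, (i + 1) % L, J1, nn_sign) for i in range(L)]
    bonds += [(i, (i + 2) % L, J2, 1.0) for i in range(L)]
    for b in range(dim):
        s = [1 - 2 * ((b >> (L - 1 - k)) & 1) for k in range(L)]
        for i, j, J, xy_sign in bonds:
            H[b, b] += J * s[i] * s[j]  # sigma^z sigma^z term
            if s[i] != s[j]:
                # Exchange term: swap the two antiparallel spins with amplitude 2.
                flipped = b ^ (1 << (L - 1 - i)) ^ (1 << (L - 1 - j))
                H[flipped, b] += xy_sign * 2.0 * J
    return H


L = 8
# The Marshall rotation is a unitary change of basis, so the spectra must agree.
ev_rot = np.linalg.eigvalsh(dense_j1j2(L, 0.3, marshall=True))
ev_raw = np.linalg.eigvalsh(dense_j1j2(L, 0.3, marshall=False))
print(np.allclose(ev_rot, ev_raw))

# Majumdar-Ghosh point: E0 = -3L/8 in S.S units = -1.5 L in Pauli units.
e0_mg = np.linalg.eigvalsh(dense_j1j2(L, 0.5))[0]
print(e0_mg, -1.5 * L)
```

The same check applies to the NetKet operator: for even L, `exact_energy(L, 0.5)` should return -1.5 L, so the L = 14 benchmark later in the tutorial should report an ED energy of about -21.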
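```python
class TransformerLogPsi(nn.Module):
    """Transformer ansatz that returns the complex log-amplitude log psi(sigma)."""
    L: int
    d_model: int = 96
    n_heads: int = 4
    n_layers: int = 6
    mlp_mult: int = 4

    @nn.compact
    def __call__(self, sigma):
        # sigma has shape (batch, L) with entries +-1; map them to tokens {0, 1}.
        x = (sigma > 0).astype(jnp.int32)
        tok = nn.Embed(num_embeddings=2, features=self.d_model)(x)
        # Learned positional embedding tells the network which site each token is on.
        pos = self.param("pos_embedding",
                         nn.initializers.normal(0.02),
                         (1, self.L, self.d_model))
        h = tok + pos

        # Pre-LayerNorm Transformer blocks with residual connections.
        for _ in range(self.n_layers):
            h_norm = nn.LayerNorm()(h)
            attn = nn.SelfAttention(
                num_heads=self.n_heads,
                qkv_features=self.d_model,
                out_features=self.d_model,
            )(h_norm)
            h = h + attn
            h2 = nn.LayerNorm()(h)
            ff = nn.Dense(self.mlp_mult * self.d_model)(h2)
            ff = nn.gelu(ff)
            ff = nn.Dense(self.d_model)(ff)
            h = h + ff

        h = nn.LayerNorm()(h)
        # Mean-pool over sites, then map to (Re, Im) of log psi.
        pooled = jnp.mean(h, axis=1)
        out = nn.Dense(2)(pooled)
        return out[..., 0] + 1j * out[..., 1]
```

We implement a Transformer-based neural quantum state in Flax. Each spin configuration is encoded as a sequence of token embeddings plus a learned positional embedding and passed through pre-LayerNorm blocks that combine multi-head self-attention with a GELU MLP, so every site can attend to every other site. Mean pooling over sites aggregates this global information, and a final dense layer outputs two real numbers a and b that form the complex log-amplitude log ψ(σ) = a + ib. This lets the model represent both the magnitude and the phase of the many-body wavefunction.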
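The Flax module is easier to reason about with the shapes spelled out. Below is a NumPy sketch of a single pre-LayerNorm block with one attention head, parameter-free LayerNorm, and random weights; the sizes, the `init_params`/`log_psi` names, and the tanh form of GELU are assumptions of this sketch, not Flax or NetKet APIs. It also shows why the learned positional embedding is essential: attention, per-site MLPs, and mean pooling are all blind to site order, so without positions every configuration with the same number of up spins (every state in the S^z = 0 sector) receives the same amplitude.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def init_params(L, d, seed, use_pos=True):
    """Hypothetical random weights for a one-block, single-head model."""
    rng = np.random.default_rng(seed)

    def w(m, n):
        return rng.normal(size=(m, n)) / np.sqrt(m)

    return {
        "embed": rng.normal(size=(2, d)),
        "pos": rng.normal(size=(L, d)) if use_pos else np.zeros((L, d)),
        "wq": w(d, d), "wk": w(d, d), "wv": w(d, d), "wo": w(d, d),
        "w1": w(d, 4 * d), "w2": w(4 * d, d), "head": w(d, 2),
    }


def log_psi(sigma, p):
    """Complex log-amplitude of one +-1 configuration of shape (L,)."""
    h = p["embed"][(sigma > 0).astype(int)] + p["pos"]   # (L, d): token + position
    hn = layer_norm(h)
    q, k, v = hn @ p["wq"], hn @ p["wk"], hn @ p["wv"]
    scores = q @ k.T / np.sqrt(q.shape[-1])               # (L, L): every site sees every site
    att = np.exp(scores - scores.max(axis=-1, keepdims=True))
    att /= att.sum(axis=-1, keepdims=True)                # row-wise softmax
    h = h + (att @ v) @ p["wo"]                           # attention + residual
    h = h + gelu(layer_norm(h) @ p["w1"]) @ p["w2"]       # MLP + residual
    out = layer_norm(h).mean(axis=0) @ p["head"]          # mean-pool, then (Re, Im)
    return out[0] + 1j * out[1]


neel = np.array([1, -1] * 4)
block = np.array([1, 1, 1, 1, -1, -1, -1, -1])  # a permutation of neel

no_pos = init_params(8, 16, seed=0, use_pos=False)
with_pos = init_params(8, 16, seed=0, use_pos=True)
print(np.isclose(log_psi(neel, no_pos), log_psi(block, no_pos)))      # identical amplitudes
print(np.isclose(log_psi(neel, with_pos), log_psi(block, with_pos)))  # distinguishable
```

Stacking more blocks or heads does not change this argument, because each layer is permutation equivariant and only the pooling step collapses the site axis.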
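```python
def structure_factor(vs, L):
    """S(q) from sigma^z correlations with site 0, estimated on the current samples."""
    # vs.samples stacks the Markov chains along the leading axes; flatten to (n_samples, L).
    spins = np.asarray(vs.samples).reshape(-1, L)
    corr = np.zeros(L)
    for r in range(L):
        corr[r] = np.mean(spins[:, 0] * spins[:, r])
    q = np.arange(L) * 2 * np.pi / L
    Sq = np.abs(np.fft.fft(corr))
    return q, Sq


def exact_energy(L, J2):
    """Lanczos ground-state energy in the total S^z = 0 sector."""
    _, hi, H = make_j1j2_chain(L, J2, total_sz=0.0)
    return nk.exact.lanczos_ed(H, k=1, compute_eigenvectors=False)[0]


def run_vmc(L, J2, n_iter=250):
    g, hi, H = make_j1j2_chain(L, J2, total_sz=0.0)
    model = TransformerLogPsi(L=L)

    # Exchange moves swap two spins, so proposals stay in the total S^z = 0 sector.
    sampler = nk.sampler.MetropolisExchange(
        hilbert=hi,
        graph=g,
        n_chains=64,
    )
    vs = nk.vqs.MCState(
        sampler,
        model,
        n_samples=4096,
        n_discard_per_chain=128,
    )

    # SR preconditions the gradient with the quantum geometric tensor (regularized by
    # diag_shift); Adam then rescales that direction. nk.optimizer.Sgd would give plain SR.
    opt = nk.optimizer.Adam(learning_rate=2e-3)
    sr = nk.optimizer.SR(diag_shift=1e-2)
    vmc = nk.driver.VMC(H, opt, variational_state=vs, preconditioner=sr)

    # With out=None no logger is attached, so the return value cannot be indexed
    # by "Energy"; record the optimization history in a RuntimeLog instead.
    log = nk.logging.RuntimeLog()
    vmc.run(n_iter=n_iter, out=log)

    # A complex ansatz stores energy means as complex numbers; keep the real part.
    energy = np.real(np.asarray(log["Energy"]["Mean"]))
    var = np.real(np.asarray(log["Energy"]["Variance"]))
    return vs, energy, var
```

We define the structure-factor observable and an exact-diagonalization (Lanczos) benchmark for validation. The VMC routine samples configurations from |ψ(σ)|² with a MetropolisExchange sampler, which conserves total S^z, and optimizes the parameters with Stochastic Reconfiguration. It returns the per-iteration energy and variance so that we can analyze convergence and accuracy.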
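Stochastic Reconfiguration is easiest to see on a toy problem where expectations can be computed exactly instead of sampled. The sketch below assumes a hypothetical fully parameterized positive ansatz, log ψ(x) = θ_x for each basis state x, and the Marshall-rotated Heisenberg chain on four sites with open boundaries, restricted to S^z = 0 (six states). Each step forms the quantum geometric tensor S_kl = ⟨ΔO_k ΔO_l⟩ and the force F_k = ⟨ΔO_k (E_loc - E)⟩ under |ψ|², solves (S + diag_shift·I) δ = F, and updates θ ← θ - η δ, which is plain SR (as with Sgd rather than Adam). It illustrates the idea behind `nk.optimizer.SR`; it is not NetKet's implementation.

```python
import numpy as np


def rotated_heisenberg_sz0(L):
    """Marshall-rotated open Heisenberg chain in Pauli units, restricted to S^z = 0."""
    sz = np.diag([1.0, -1.0])
    exchange = np.array([[0, 0, 0, 0],
                         [0, 0, 2, 0],
                         [0, 2, 0, 0],
                         [0, 0, 0, 0]], dtype=float)
    bond = np.kron(sz, sz) - exchange  # Marshall sign on the exchange term
    H = sum(np.kron(np.kron(np.eye(2**i), bond), np.eye(2**(L - i - 2)))
            for i in range(L - 1))
    sector = [b for b in range(2**L) if bin(b).count("1") == L // 2]
    return H[np.ix_(sector, sector)]


def sr_step(theta, H, eta=0.02, diag_shift=1e-4):
    """One SR update with exact expectations for psi(x) = exp(theta_x) > 0."""
    psi = np.exp(theta)
    p = psi**2 / np.sum(psi**2)          # Born probabilities |psi(x)|^2
    e_loc = (H @ psi) / psi              # local energies E_loc(x)
    energy = p @ e_loc                   # variational energy <H>
    dO = np.eye(len(theta)) - p          # centered log-derivatives; row x, column k
    S = dO.T @ (p[:, None] * dO)         # quantum geometric tensor
    F = dO.T @ (p * (e_loc - energy))    # force = half the energy gradient
    delta = np.linalg.solve(S + diag_shift * np.eye(len(theta)), F)
    return theta - eta * delta, energy


H = rotated_heisenberg_sz0(4)
e_exact = np.linalg.eigvalsh(H)[0]
theta = np.zeros(H.shape[0])             # start from the uniform state
for _ in range(3000):
    theta, energy = sr_step(theta, H)
print(energy, e_exact)
```

With one parameter per basis state, S⁻¹F reduces to E_loc(x) - E up to an irrelevant constant, so each step is an imaginary-time step in log-amplitude space. With a neural network, SR projects that same step onto the directions the network can represent.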
```python
L = 24
J2_values = np.linspace(0.0, 0.7, 6)
energies = []
structure_peaks = []

for J2 in tqdm(J2_values):
    # Train an independent variational state for each coupling.
    vs, e, var = run_vmc(L, J2)
    energies.append(e[-1])
    q, Sq = structure_factor(vs, L)
    structure_peaks.append(np.max(Sq))
```

We sweep six values of J2 between 0 and 0.7 to explore the frustrated phase diagram, training a separate variational state for each coupling and recording its final energy. For each point we also record the height of the structure-factor peak, whose changes across the sweep can signal a change in magnetic ordering.
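The `structure_factor` function correlates every site only with site 0 and the sweep keeps only the peak height. On a periodic chain every site is equivalent, so averaging the correlator over all starting sites reuses each sample L times, and the peak position q* is often more telling than its height: a Néel pattern peaks at q = π, while a period-4 pattern ↑↑↓↓ peaks at q = π/2. A minimal NumPy sketch, with the hypothetical helpers `structure_factor_avg` and `peak_of`:

```python
import numpy as np


def structure_factor_avg(samples):
    """Hypothetical helper: translation-averaged S(q) from +-1 samples of shape (..., L)."""
    spins = np.asarray(samples, dtype=float)
    spins = spins.reshape(-1, spins.shape[-1])
    L = spins.shape[1]
    # C(r) = mean over samples and sites i of s_i * s_{i+r} (periodic boundaries)
    corr = np.array([np.mean(spins * np.roll(spins, -r, axis=1)) for r in range(L)])
    q = 2 * np.pi * np.arange(L) / L
    # After site averaging C(r) = C(L - r), so S(q) is real.
    Sq = np.fft.fft(corr).real
    return q, Sq


def peak_of(q, Sq):
    """Location and height of the maximum of S(q) for 0 <= q <= pi."""
    half = len(q) // 2 + 1  # S(q) = S(2*pi - q), so the upper half is redundant
    k = int(np.argmax(Sq[:half]))
    return q[k], Sq[k]


neel = np.array([[1, -1] * 4])           # one sample, L = 8
dimer = np.array([[1, 1, -1, -1] * 2])
print(peak_of(*structure_factor_avg(neel)))   # q* = pi, height 8
print(peak_of(*structure_factor_avg(dimer)))  # q* = pi/2, height 4
```

In the sweep, `peak_of(*structure_factor_avg(vs.samples))` could be recorded alongside `np.max(Sq)`; the helper flattens the chain axes of `vs.samples` itself.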
```python
# Benchmark against exact diagonalization on a smaller chain.
L_ed = 14
J2_test = 0.5
E_ed = exact_energy(L_ed, J2_test)
vs_small, e_small, _ = run_vmc(L_ed, J2_test, n_iter=200)
E_vmc = e_small[-1]
print(f"ED energy (L={L_ed}):", E_ed)
print("VMC energy:", E_vmc)
print("Absolute error:", abs(E_vmc - E_ed))

plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(e_small)
plt.xlabel("Iteration")
plt.ylabel("Energy")
plt.title(f"Energy Convergence (L={L_ed})")

plt.subplot(1, 3, 2)
plt.plot(J2_values, energies, "o-")
plt.xlabel("J2")
plt.title("Energy vs J2")

plt.subplot(1, 3, 3)
plt.plot(J2_values, structure_peaks, "o-")
plt.xlabel("J2")
plt.title("Structure Factor Peak")

plt.tight_layout()
plt.show()
```

We benchmark the model against exact diagonalization on a smaller chain (L = 14) at J2 = 0.5 and report the absolute error of the final VMC energy relative to the Lanczos result. The plots summarize the convergence of the optimization, the dependence of the energy on J2, and the structure-factor peak across the sweep.
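The benchmark compares the ED energy with the energy of the last iteration, which is a single Monte Carlo estimate that fluctuates from step to step. A steadier figure of merit averages the last few iterations and reports the relative error. The sketch below assumes a hypothetical window of 50 iterations; because successive VMC iterations are correlated and the parameters are still moving, its naive standard error likely underestimates the true uncertainty.

```python
import numpy as np


def tail_energy(energy, n_tail=50):
    """Mean and naive standard error of the last n_tail VMC energies."""
    tail = np.real(np.asarray(energy))[-n_tail:]
    return tail.mean(), tail.std(ddof=1) / np.sqrt(len(tail))


def relative_error(e_vmc, e_exact):
    """|E_vmc - E_exact| / |E_exact|."""
    return abs(e_vmc - e_exact) / abs(e_exact)


# Toy history (not VMC output): a decay followed by a plateau oscillating around -10.
plateau = -10.0 + 0.1 * (-1.0) ** np.arange(50)
history = np.concatenate([np.linspace(-5.0, -9.5, 150), plateau])
mean, err = tail_energy(history)
print(mean, err, relative_error(mean, -10.0))
```

In the benchmark cell, `tail_energy(e_small)` and `relative_error(mean, E_ed)` could replace the single-iteration comparison.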
In conclusion, we combined a Transformer neural network with Variational Monte Carlo to study a frustrated quantum magnet, in a workflow designed to scale beyond the system sizes accessible to exact diagonalization. We validated the Transformer ansatz against Lanczos diagonalization, analyzed its convergence, and extracted observables such as the structure-factor peak as indicators of changes in magnetic order. We also established a flexible foundation that can be extended toward higher-dimensional lattices, symmetry-projected states, entanglement diagnostics, and time-dependent quantum simulations.







