Source code for kosmos.ml.models.vqc.circuit.qiskit_circuit.autograd_function

from collections.abc import Callable
from typing import Protocol

import numpy as np
import torch
from torch.autograd.function import FunctionCtx


[docs] class AutogradCtx(Protocol): """Autograd context for the QiskitAutogradFunction.""" saved_tensors: tuple[torch.Tensor, ...] gradient_fn: Callable[[torch.Tensor, np.ndarray], torch.Tensor]
[docs] class QiskitAutogradFunction(torch.autograd.Function): """Custom autograd bridge between PyTorch and Qiskit for variational quantum circuits."""
[docs] @staticmethod def forward( ctx: FunctionCtx, x: torch.Tensor, weights: torch.Tensor, evaluator: Callable, gradient_fn: Callable, ) -> torch.Tensor: """Perform the forward pass by evaluating the quantum circuit via the provided evaluator. Args: ctx (FunctionCtx): Autograd context to save information for the backward pass. x (torch.Tensor): Input batch of shape (B, input_dim). weights (torch.Tensor): Trainable weights of the circuit. evaluator (Callable): Callable that evaluates the quantum circuit and returns the output values for the given inputs and weights. gradient_fn (Callable): Callable computing gradients of the circuit output with respect to its parameters, e.g., via the parameter-shift rule. Returns: torch.Tensor: Output batch of shape (B, output_dim). """ ctx.save_for_backward(x, weights) ctx.gradient_fn = gradient_fn ctx.evaluator = evaluator return evaluator(x, weights.detach().cpu().numpy())
[docs] @staticmethod def backward( ctx: AutogradCtx, *grad_outputs: torch.Tensor ) -> tuple[None, torch.Tensor, None, None]: """Compute gradients using the provided quantum gradient function. Args: ctx (AutogradCtx): Autograd context with saved tensors from the forward pass. *grad_outputs (torch.Tensor): Gradient of the loss with respect to the forward output, shape (B, output_dim). Returns: tuple[None, torch.Tensor, None, None]: Gradients for each forward input. Only the gradient with respect to `weights` is returned; gradients for `model` and `x` are None. """ x, weights = ctx.saved_tensors (grad_output,) = grad_outputs # Compute dOutput/dWeights via parameter-shift jac = ctx.gradient_fn(x, weights.detach().cpu().numpy()) jac_np = jac.detach().cpu().numpy() # shape: (B, output_dim, *weights.shape) grad_out_np = grad_output.detach().cpu().numpy() # shape: (B, output_dim) # Chain rule: dL/dW = sum(dL/dOut * dOut/dW) grad_w = np.tensordot(grad_out_np, jac_np, axes=([0, 1], [0, 1])) return None, torch.from_numpy(grad_w).float(), None, None