kosmos.ml.models.vqc.circuit.qiskit_circuit.autograd_function

Classes

class AutogradCtx

Bases: Protocol

Autograd context for QiskitAutogradFunction, used as the type of ctx in backward.


class QiskitAutogradFunction(*args, **kwargs)

Bases: torch.autograd.Function

Custom autograd bridge between PyTorch and Qiskit for variational quantum circuits.
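The bridging pattern can be sketched with a toy torch.autograd.Function that mirrors the documented forward/backward signatures. This is a minimal sketch, not the kosmos implementation. The following are hypothetical: the class name ToyBridge, the toy evaluator and gradient function, and the assumption that gradient_fn returns a Jacobian of shape (B, output_dim, n_weights). A closed-form cos(x + w) stands in for a Qiskit circuit, so the example runs without a quantum backend.

```python
import numpy as np
import torch


def toy_evaluator(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a circuit: output_dim == input_dim, and each
    # output is cos(x_i + w_i), the <Z> value of one qubit after RY(x_i + w_i).
    return np.cos(x + w)


def toy_gradient_fn(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # Hypothetical contract: Jacobian d(output)/d(weights) of shape
    # (B, output_dim, n_weights). Here it is diagonal in (output, weight).
    d = w.shape[0]
    return -np.sin(x + w)[:, :, None] * np.eye(d)[None, :, :]


class ToyBridge(torch.autograd.Function):
    # Hypothetical class mirroring the documented forward/backward signatures.

    @staticmethod
    def forward(ctx, x, weights, evaluator, gradient_fn):
        # Keep the tensors and the gradient callable for backward.
        ctx.save_for_backward(x, weights)
        ctx.gradient_fn = gradient_fn
        # The evaluator runs outside autograd, so hand it NumPy arrays and
        # wrap its result back into a tensor.
        out = evaluator(x.detach().cpu().numpy(), weights.detach().cpu().numpy())
        return torch.as_tensor(out, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, *grad_outputs):
        x, weights = ctx.saved_tensors
        grad_out = grad_outputs[0]  # shape (B, output_dim)
        jac = ctx.gradient_fn(x.detach().cpu().numpy(), weights.detach().cpu().numpy())
        jac = torch.as_tensor(jac, dtype=grad_out.dtype, device=grad_out.device)
        # Vector-Jacobian product, summed over batch and outputs.
        grad_weights = torch.einsum("bo,bop->p", grad_out, jac)
        # One entry per forward input: x, weights, evaluator, gradient_fn.
        return None, grad_weights, None, None


x = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float64)
w = torch.tensor([0.5, -0.5], dtype=torch.float64, requires_grad=True)
out = ToyBridge.apply(x, w, toy_evaluator, toy_gradient_fn)
out.sum().backward()
```

Because backward returns None in the x slot, x.grad stays None and only w receives a gradient, matching the documented return type. The callables pass through apply as ordinary Python objects, which is why their gradient slots must also be None.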


Methods

forward(ctx: torch.autograd.function.FunctionCtx, x: torch.Tensor, weights: torch.Tensor, evaluator: collections.abc.Callable, gradient_fn: collections.abc.Callable) → torch.Tensor

Perform the forward pass by evaluating the quantum circuit via the provided evaluator.

Parameters:
  • ctx (FunctionCtx) – Autograd context to save information for the backward pass.

  • x (torch.Tensor) – Input batch of shape (B, input_dim).

  • weights (torch.Tensor) – Trainable weights of the circuit.

  • evaluator (Callable) – Callable that evaluates the quantum circuit and returns the output values for the given inputs and weights.

  • gradient_fn (Callable) – Callable computing gradients of the circuit output with respect to its parameters, e.g., via the parameter-shift rule.

Returns:

Output batch of shape (B, output_dim).

Return type:

torch.Tensor
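The gradient_fn description names the parameter-shift rule as one option. The following is a minimal NumPy sketch of that rule. It assumes each weight enters the circuit through a rotation exp(−iθP/2), with P a Pauli operator, so the output is a sinusoid in each weight and the two-point formula [f(θ + s) − f(θ − s)] / (2 sin s) is exact. With s = π/2 this reduces to the usual halved difference. The helper name parameter_shift_jacobian, its (B, output_dim, n_weights) output shape, and the closed-form toy_circuit standing in for a Qiskit evaluation are all hypothetical.

```python
import numpy as np


def parameter_shift_jacobian(f, x, weights, shift=np.pi / 2):
    # Hypothetical helper: Jacobian of f(x, weights) w.r.t. weights.
    # f returns shape (B, output_dim); the result is (B, output_dim, n_weights).
    # Exact when f has the form a*cos(w_j) + b*sin(w_j) + c in each weight,
    # as for rotations exp(-i * w_j * P / 2) with P a Pauli operator.
    columns = []
    for j in range(weights.shape[0]):
        step = np.zeros_like(weights)
        step[j] = shift
        plus = f(x, weights + step)   # circuit run with w_j shifted up
        minus = f(x, weights - step)  # circuit run with w_j shifted down
        columns.append((plus - minus) / (2 * np.sin(shift)))
    return np.stack(columns, axis=-1)


def toy_circuit(x, w):
    # Hypothetical closed-form stand-in for a Qiskit evaluation: the product of
    # <Z> on qubit 0 after RY(x + w[0]) and on qubit 1 after RY(w[1]).
    return np.cos(x + w[0]) * np.cos(w[1])


x = np.array([[0.1], [0.7]])  # batch B = 2, input_dim = 1
w = np.array([0.3, -0.4])
jac = parameter_shift_jacobian(toy_circuit, x, w)
analytic = np.stack(
    [-np.sin(x + w[0]) * np.cos(w[1]), -np.cos(x + w[0]) * np.sin(w[1])],
    axis=-1,
)
```

Each weight costs two extra circuit evaluations, so the cost grows linearly with the number of weights.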

backward(ctx: AutogradCtx, *grad_outputs: torch.Tensor) → tuple[None, torch.Tensor, None, None]

Compute gradients using the provided quantum gradient function.

Parameters:
  • ctx (AutogradCtx) – Autograd context with saved tensors from the forward pass.

  • *grad_outputs (torch.Tensor) – Gradient of the loss with respect to the forward output, shape (B, output_dim).

Returns:

Gradients for each forward input, in the order (x, weights, evaluator, gradient_fn). Only the gradient with respect to weights is returned as a tensor; the entries for x, evaluator, and gradient_fn are None.

Return type:

tuple[None, torch.Tensor, None, None]
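The weights gradient returned by backward follows from the chain rule. Write the forward output as f(x; w) with entries f_{b,o}, and let O = output_dim. Assume gradient_fn supplies the partial derivatives ∂f_{b,o}/∂w_p, for example via the parameter-shift rule. Under that assumption, the weights gradient is the vector-Jacobian product summed over the batch:

```latex
\frac{\partial L}{\partial w_p}
  = \sum_{b=1}^{B} \sum_{o=1}^{O}
    g_{b,o}\, \frac{\partial f_{b,o}}{\partial w_p},
\qquad
g_{b,o} = \frac{\partial L}{\partial f_{b,o}}
```

Here g is grad_outputs[0], of shape (B, output_dim).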