from typing import Any
import numpy as np
from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister
from kosmos.circuit_runner.qiskit_result import calculate_expectation_values
from kosmos.circuit_runner.qiskit_runner import AerSimulatorRunner
from kosmos.dqc_scheduling.assignment_strategies.assignment_strategy import AssignmentStrategy
from kosmos.dqc_scheduling.assignment_strategies.greedy_assignment import GreedyAssignment
from kosmos.dqc_scheduling.event import EventId
from kosmos.dqc_scheduling.event_queue import EventQueue
from kosmos.dqc_scheduling.execution_scheduler import ExecutionScheduler
from kosmos.dqc_scheduling.utils.protocol_scheduling_utils import (
schedule_circuit_execution,
schedule_remote_operation,
)
from kosmos.dqc_scheduling.utils.timing_utils import calculate_parallel_comm_time
from kosmos.partitioning.circuit_converter import to_gate_list
from kosmos.partitioning.partition import Partition
from kosmos.protocols.config.protocol import EGProtocolConfig
from kosmos.protocols.eg_protocol import EGProtocol
from kosmos.protocols.protocol_result import (
CircuitExecutionProtocolResult,
CommunicationProtocolResult,
)
from kosmos.quantum_logic.quantum_register_manager import QuantumRegisterManager
from kosmos.simulator.simulator import Simulator
from kosmos.topology.net import Network
from kosmos.topology.node import NodeId, QuantumNode
class DQCSimulator(Simulator):
    """Distributed Quantum Computing simulator.

    Executes quantum circuits in a distributed manner across multiple nodes.

    Attributes:
        quantum_manager (QuantumRegisterManager): Quantum register manager for handling
            entanglements, qubits and states across the network.
        circuit_runner (AerSimulatorRunner): Qiskit Aer simulator for executing
            quantum circuits.
        assignment_strategy_cls (type[AssignmentStrategy]): Assignment strategy class.
        eg_protocol_cls (type[EGProtocol]): Entanglement generation protocol class.
        eg_protocol_config (EGProtocolConfig): Config for entanglement generation protocol.
        execution_scheduler (ExecutionScheduler | None): Scheduler that determines the
            execution order of circuit partitions and remote operations.
        execution_results (dict[str, Any]): Dictionary containing execution metrics
            including timing, counts, and expectation values.
    """

    # Minimum number of "_"-separated tokens an entanglement-generation ("eg_") event id
    # must have before _get_gate_time_for_event resolves it to a remote operation.
    MIN_EG_EVENT_ID_PARTS = 4
    # Minimum number of tokens _get_nodes_for_event needs: it reads tokens 1-4 as two
    # node ids of two tokens each ("eg_<a>_<b>_<c>_<d>..." -> "a_b", "c_d").
    _MIN_EG_NODE_ID_PARTS = 5
    # Event-id prefixes that identify remote operations in the execution scheduler.
    _REMOTE_OP_PREFIXES = ("remote_gate_", "teleportation_")

    def __init__(
        self,
        network: Network,
        seed: int = 1,
        assignment_strategy_cls: type[AssignmentStrategy] = GreedyAssignment,
        eg_protocol_cls: type[EGProtocol] = EGProtocol,
        eg_protocol_config: EGProtocolConfig | None = None,
    ) -> None:
        """Initialize the DQC simulator.

        Args:
            network (Network): The quantum network topology.
            seed (int): Random seed for reproducibility. Defaults to 1.
            assignment_strategy_cls (type[AssignmentStrategy]): Assignment strategy class.
                Defaults to GreedyAssignment.
            eg_protocol_cls (type[EGProtocol]): EGProtocol class (or subclass) to use for
                entanglement generation. Defaults to EGProtocol.
            eg_protocol_config (EGProtocolConfig | None): Configuration for EGProtocol.
                If None, a default EGProtocolConfig() is created. Defaults to None.

        Raises:
            ValueError: If any quantum node in the network has no communication qubits.
        """
        super().__init__(network, seed)
        self._validate_network_for_dqc()
        self.quantum_manager = QuantumRegisterManager(network)
        self._event_queue = EventQueue()
        self.circuit_runner = AerSimulatorRunner()
        self.assignment_strategy_cls = assignment_strategy_cls
        self.eg_protocol_cls = eg_protocol_cls
        self.eg_protocol_config = (
            eg_protocol_config if eg_protocol_config is not None else EGProtocolConfig()
        )
        self._partition: Partition | None = None
        self.execution_scheduler: ExecutionScheduler | None = None
        self.execution_results: dict[str, Any] = {}
        self._circuit_execution_result: CircuitExecutionProtocolResult | None = None

    def load_circuit(self, circuit: QuantumCircuit | str) -> None:
        """Partition and distribute a quantum circuit for distributed execution.

        The whole circuit is wrapped in a single partition ("partition_0") with an
        identity logical-qubit mapping. The configured assignment strategy allocates
        it onto the network, and an ExecutionScheduler is built from the resulting
        space-time matrix.

        Args:
            circuit (QuantumCircuit | str): Qiskit QuantumCircuit or QASM string.

        Raises:
            TypeError: If ``circuit`` is neither a QuantumCircuit nor a string.
        """
        if isinstance(circuit, str):
            circuit = QuantumCircuit.from_qasm_str(circuit)
        elif not isinstance(circuit, QuantumCircuit):
            msg = "Invalid circuit type. QASM and qiskit circuit are supported."
            raise TypeError(msg)
        num_gates = len(to_gate_list(circuit))
        self._partition = Partition(
            id="partition_0",
            circuit=circuit,
            logical_qubit_mapping={i: i for i in range(circuit.num_qubits)},
            original_gate_indices=list(range(num_gates)),
        )
        assignment_strategy = self.assignment_strategy_cls(self.network)
        space_time_matrix = assignment_strategy.allocate(self._partition)
        self.execution_scheduler = ExecutionScheduler(space_time_matrix)

    def schedule_protocols(self) -> None:
        """Schedule communication & circuit execution protocols.

        Walks the scheduler's execution order. Steps that carry a circuit are
        scheduled as circuit executions; all other steps are scheduled as remote
        operations.

        Raises:
            ValueError: If no circuit has been loaded via ``load_circuit``.
        """
        if self.execution_scheduler is None:
            msg = "No circuit loaded. Call load_circuit() first."
            raise ValueError(msg)
        execution_order = self.execution_scheduler.get_execution_order()
        # Shared bookkeeping passed to every remote-operation scheduling call.
        last_remote_op_per_node_pair: dict[tuple[NodeId, NodeId], EventId] = {}
        # EventIds using comm qubits at a certain gate_time per node
        node_comm_usage: dict[NodeId, dict[int, list[EventId]]] = {}
        for step in execution_order:
            if step.circuit is not None:
                schedule_circuit_execution(
                    step=step,
                    event_queue=self._event_queue,
                    network=self.network,
                    circuit_runner=self.circuit_runner,
                )
            else:
                schedule_remote_operation(
                    step=step,
                    event_queue=self._event_queue,
                    network=self.network,
                    quantum_manager=self.quantum_manager,
                    eg_protocol_cls=self.eg_protocol_cls,
                    eg_protocol_config=self.eg_protocol_config,
                    last_remote_op_per_node_pair=last_remote_op_per_node_pair,
                    node_comm_usage=node_comm_usage,
                )

    def run(self) -> dict[str, Any]:
        """Execute all scheduled events and aggregate execution metrics.

        Events are grouped by the space-time-matrix timestep they belong to. For each
        timestep, the computation time is the longest circuit-execution time, and the
        communication time comes from ``calculate_parallel_comm_time``. The timestep
        duration is the sum of the two.

        Returns:
            dict[str, Any]: Execution results with the keys ``computation_time``,
            ``communication_time``, ``total_time``, ``final_time``,
            ``num_remote_gates``, ``num_teleportations``, ``counts``,
            ``expectation_values`` and ``density_matrix``. The same dict is also
            stored in ``self.execution_results``.

        Raises:
            ValueError: If no events have been scheduled.
        """
        if not self._event_queue.queue and not self._event_queue.waiting_events:
            msg = "No events scheduled. Call schedule_protocols() first."
            raise ValueError(msg)
        self._event_queue.run()
        operations_by_timestep: dict[int, dict[str, list[tuple[EventId, int]]]] = {}
        for event_id, result in self._event_queue.event_results.items():
            gate_time = self._get_gate_time_for_event(event_id)
            if gate_time not in operations_by_timestep:
                operations_by_timestep[gate_time] = {"computation": [], "communication": []}
            if isinstance(result, CommunicationProtocolResult):
                operations_by_timestep[gate_time]["communication"].append(
                    (event_id, result.execution_time)
                )
            elif isinstance(result, CircuitExecutionProtocolResult):
                operations_by_timestep[gate_time]["computation"].append(
                    (event_id, result.execution_time)
                )
                # If several circuit-execution results exist, the last one iterated
                # is the one measured by measure_final_state().
                self._circuit_execution_result = result
        total_time = 0
        computation_time = 0
        communication_time = 0
        for gate_time in sorted(operations_by_timestep.keys()):
            ops = operations_by_timestep[gate_time]
            comp_time = max((t for _, t in ops["computation"]), default=0)
            comm_time = calculate_parallel_comm_time(
                ops["communication"], self.network, self._get_nodes_for_event
            )
            timestep_duration = comp_time + comm_time
            total_time += timestep_duration
            computation_time += comp_time
            communication_time += comm_time
        counts, density_matrix = self.measure_final_state()
        self.execution_results = {
            "computation_time": computation_time,
            "communication_time": communication_time,
            "total_time": total_time,
            "final_time": self._event_queue.current_time,
            "num_remote_gates": len(self.execution_scheduler.remote_gates)
            if self.execution_scheduler
            else 0,
            "num_teleportations": len(self.execution_scheduler.teleportations)
            if self.execution_scheduler
            else 0,
            "counts": counts,
            "expectation_values": calculate_expectation_values(counts),
            "density_matrix": density_matrix,
        }
        return self.execution_results

    def measure_final_state(self) -> tuple[dict[str, int], np.ndarray]:
        """Measure the final quantum state after the circuit partition has finished.

        The density matrix from the stored circuit-execution result is loaded into a
        fresh circuit. All qubits are measured, and the circuit is run on the
        runner's backend with ``circuit_runner.num_shots`` shots.

        Returns:
            tuple[dict[str, int], np.ndarray]: Raw measurement counts (bitstring ->
            count) and the density matrix that was measured.
            ``({}, np.array([]))`` is returned when no circuit is loaded, no
            circuit-execution result is stored, or it carries no density matrix.
        """
        # Explicit ``is None`` checks: QuantumCircuit defines __len__ (its number of
        # instructions), so a truthiness test would wrongly treat a loaded circuit
        # that has no gates as "no circuit".
        if (
            self._partition is None
            or self._partition.circuit is None
            or self._circuit_execution_result is None
        ):
            return {}, np.array([])
        density_matrix = self._circuit_execution_result.density_matrix
        if density_matrix is None:
            return {}, np.array([])
        num_qubits = self._partition.circuit.num_qubits
        qr = QuantumRegister(num_qubits, "q")
        cr = ClassicalRegister(num_qubits, "c")
        measure_circuit = QuantumCircuit(qr, cr)
        measure_circuit.set_density_matrix(density_matrix)
        measure_circuit.measure(qr, cr)
        job = self.circuit_runner.backend.run(measure_circuit, shots=self.circuit_runner.num_shots)
        result = job.result()
        return result.get_counts(0), density_matrix

    def _validate_network_for_dqc(self) -> None:
        """Validate that the network is suitable for distributed quantum computing.

        Raises:
            ValueError: If any QuantumNode has ``communication_qubits == 0``. The
                message lists the ids of the offending nodes.
        """
        nodes_without_comm_qubits = [
            node.id.value
            for node in self.network.nodes()
            if isinstance(node, QuantumNode) and node.communication_qubits == 0
        ]
        if nodes_without_comm_qubits:
            msg = (
                f"Distributed quantum computing requires all quantum nodes to have "
                f"communication qubits. The following nodes have no "
                f"communication qubits: {', '.join(nodes_without_comm_qubits)}"
            )
            raise ValueError(msg)

    @classmethod
    def _find_remote_op_id(cls, value: str) -> str | None:
        """Return the remote-operation id embedded at the end of an event id, if any.

        Args:
            value (str): Raw event id string.

        Returns:
            str | None: Suffix of ``value`` starting at the first occurrence of a
            known remote-operation prefix, or None if no such prefix occurs.
        """
        hits = [pos for pos in (value.find(p) for p in cls._REMOTE_OP_PREFIXES) if pos >= 0]
        return value[min(hits):] if hits else None

    def _get_gate_time_for_event(self, event_id: EventId) -> int:
        """Get the gate_time (timestep) for an event from the space-time matrix.

        Args:
            event_id (EventId): Event identifier.

        Returns:
            int: Gate timestep where the gate is executed, retrieved from the
            space-time matrix. ``eg_`` events inherit the timestep of the remote
            operation they belong to. Unrecognized ids map to 0.
        """
        value = event_id.value
        if value.startswith("partition_"):
            partition = self.execution_scheduler.partitions[value]
            return partition.start_time
        if value.startswith("remote_gate_"):
            idx = int(value.split("_")[-1])
            return self.execution_scheduler.remote_gates[idx].gate_time
        if value.startswith("teleportation_"):
            idx = int(value.split("_")[-1])
            return self.execution_scheduler.teleportations[idx].gate_time
        if value.startswith("eg_"):
            parts = value.split("_")
            if len(parts) >= self.MIN_EG_EVENT_ID_PARTS:
                # NOTE(review): _get_nodes_for_event assumes the node ids after "eg"
                # span two tokens each, in which case tokens 3+ are not the remote-op
                # id. Searching for a known remote-op prefix handles both layouts;
                # the token-3 join is kept as the fallback. Confirm against the id
                # format produced by schedule_remote_operation.
                remote_op_id = self._find_remote_op_id(value) or "_".join(parts[3:])
                return self._get_gate_time_for_event(EventId(remote_op_id))
        return 0

    def _get_nodes_for_event(self, event_id: EventId) -> list[NodeId]:
        """Get the nodes involved in an event.

        Args:
            event_id (EventId): Event identifier.

        Returns:
            list[NodeId]: Involved node identifiers for the provided event. For
            ``eg_`` events these are parsed from tokens 1-4 of the id. For remote
            gates and teleportations they are the operation's source and target
            nodes. Any other id, or an ``eg_`` id too short to parse, yields an
            empty list.
        """
        value = event_id.value
        if value.startswith("eg_"):
            parts = value.split("_")
            # parts[4] is read below, so at least five tokens are required.
            if len(parts) >= self._MIN_EG_NODE_ID_PARTS:
                return [NodeId(f"{parts[1]}_{parts[2]}"), NodeId(f"{parts[3]}_{parts[4]}")]
        elif value.startswith(self._REMOTE_OP_PREFIXES):
            idx = int(value.split("_")[-1])
            if value.startswith("remote_gate_"):
                remote_op = self.execution_scheduler.remote_gates[idx]
            else:
                remote_op = self.execution_scheduler.teleportations[idx]
            return [remote_op.source_node.id, remote_op.target_node.id]
        return []