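"""Execution-order scheduling (``kosmos.dqc_scheduling.execution_scheduler``).

Orders partitions, remote CNOTs and teleportations into dependency-respecting
execution steps.
"""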

from collections import deque
from dataclasses import dataclass

from qiskit import QuantumCircuit

from kosmos.dqc_scheduling.event import EventId
from kosmos.dqc_scheduling.remote_operation_detector import (
    RemoteOperationDetector,
    RemoteOperationInfo,
)
from kosmos.dqc_scheduling.space_time_matrix import SpaceTimeMatrix
from kosmos.partitioning.partition import Partition


@dataclass
class ExecutionStep:
    """A single entry in the execution order: a partition, a remote CNOT, or a teleportation.

    Attributes:
        event_id (EventId): Identifier of the scheduled event.
        dependencies (set[EventId]): IDs of the events that must finish before
            this step may run.
        is_remote_gate (bool): True when this step performs a remote gate
            operation. Defaults to False.
        is_teleportation (bool): True when this step performs a teleportation.
            Defaults to False.
        circuit (QuantumCircuit | None): Circuit executed by this step, or None
            when the step carries no circuit. Defaults to None.
        remote_operation_info (RemoteOperationInfo | None): Details of the
            remote operation, or None when not applicable. Defaults to None.
    """

    # Required fields; they must precede every field that has a default.
    event_id: EventId
    dependencies: set[EventId]

    # Kind flags describing what this step does.
    is_remote_gate: bool = False
    is_teleportation: bool = False

    # Optional payloads; which one is populated depends on the kind of step.
    circuit: QuantumCircuit | None = None
    remote_operation_info: RemoteOperationInfo | None = None
class ExecutionScheduler:
    """Manages execution dependencies using a space-time matrix.

    On construction, remote gates and teleportations are detected from the
    matrix and a dependency graph is built over three kinds of events:
    partitions (keyed by their partition ID), ``remote_gate_<i>`` and
    ``teleportation_<i>``, where ``<i>`` indexes ``remote_gates`` and
    ``teleportations`` respectively.

    Attributes:
        space_time_matrix (SpaceTimeMatrix): Matrix storing assignment of qubits.
        remote_gates (list[RemoteOperationInfo]): List of detected remote gate operations.
        teleportations (list[RemoteOperationInfo]): List of detected qubit teleportations.
    """

    def __init__(self, matrix: SpaceTimeMatrix) -> None:
        """Initialize the scheduler and build its execution graph.

        Args:
            matrix (SpaceTimeMatrix): Matrix storing assignment of qubits.
        """
        self.space_time_matrix: SpaceTimeMatrix = matrix
        self._partitions: dict[str, Partition] = matrix.partition_info

        detector = RemoteOperationDetector(matrix)
        self.remote_gates: list[RemoteOperationInfo] = detector.detect_remote_gates()
        self.teleportations: list[RemoteOperationInfo] = detector.detect_teleportations()

        # Must run last: the graph builder reads remote_gates and teleportations.
        self._execution_graph: dict[EventId, set[EventId]] = self._build_execution_graph()

    def _build_execution_graph(self) -> dict[EventId, set[EventId]]:
        """Build the dependency graph used to derive the execution order.

        Each event maps to the set of events it must wait for:

        * a partition waits for every other partition that shares a physical
          qubit with it, whose ``end_time`` is not None and is <= the
          partition's ``start_time``;
        * a partition waits for ``remote_gate_<i>`` when it uses the gate's
          control or target qubit and starts strictly after the gate time;
        * a partition waits for ``teleportation_<i>`` when it uses the
          teleported qubit, that qubit is on the teleportation's target node
          within the partition, and the partition starts at or after the
          teleportation time.

        NOTE(review): remote gates and teleportations get no dependencies of
        their own, so they are ready as soon as the topological sort starts,
        regardless of partitions that precede them in time. The remote-gate
        (``>``) and teleportation (``>=``) time comparisons also differ.
        Confirm both behaviours are intended.

        Returns:
            dict[EventId, set[EventId]]: Mapping from each event to the events
            that must complete before it. Partitions are inserted first, then
            remote gates, then teleportations.
        """
        dag: dict[EventId, set[EventId]] = {}

        for partition_id, partition in self._partitions.items():
            qubits = partition.physical_qubits_used
            dag[EventId(partition_id)] = {
                EventId(other_id)
                for other_id, other_partition in self._partitions.items()
                if other_id != partition_id
                and other_partition.end_time is not None
                and any(q in other_partition.physical_qubits_used for q in qubits)
                and other_partition.end_time <= partition.start_time
            }

        for i, remote_gate in enumerate(self.remote_gates):
            gate_event_id = EventId(f"remote_gate_{i}")
            dag[gate_event_id] = set()
            # The qubits touched by the gate do not vary per partition; compute once.
            affected_qubits = {remote_gate.control_qubit, remote_gate.target_qubit}
            for partition_id, partition in self._partitions.items():
                if (
                    any(q in partition.physical_qubits_used for q in affected_qubits)
                    and partition.start_time > remote_gate.gate_time
                ):
                    dag[EventId(partition_id)].add(gate_event_id)

        for i, teleportation in enumerate(self.teleportations):
            teleport_event_id = EventId(f"teleportation_{i}")
            dag[teleport_event_id] = set()
            for partition_id, partition in self._partitions.items():
                # Only partitions that use the teleported qubit can depend on it;
                # skip the node lookup for all others.
                if teleportation.qubit not in partition.physical_qubits_used:
                    continue
                partition_node = partition.get_node_for_qubit(teleportation.qubit)
                if (
                    partition_node is not None
                    and partition_node.id == teleportation.target_node.id
                    and partition.start_time >= teleportation.gate_time
                ):
                    dag[EventId(partition_id)].add(teleport_event_id)

        return dag

    def _make_step(self, event_id: EventId, dependencies: set[EventId]) -> ExecutionStep:
        """Create the ExecutionStep for one event of the execution graph.

        Partition events are recognised by membership in the partition mapping
        rather than by an ID prefix, so partitions with any ID are scheduled.
        Remote-gate and teleportation events carry their list index as the last
        ``_``-separated component of their ID.

        Args:
            event_id (EventId): Event to convert.
            dependencies (set[EventId]): Events that must complete first. The
                step receives a copy, so mutating it cannot alter the
                scheduler's graph.

        Returns:
            ExecutionStep: The step describing ``event_id``.

        Raises:
            ValueError: If ``event_id`` matches none of the known event kinds.
        """
        name = event_id.value
        deps = set(dependencies)

        if name in self._partitions:
            return ExecutionStep(
                event_id=event_id,
                dependencies=deps,
                circuit=self._partitions[name].circuit,
            )
        if name.startswith("remote_gate_"):
            gate_idx = int(name.split("_")[-1])
            return ExecutionStep(
                event_id=event_id,
                dependencies=deps,
                is_remote_gate=True,
                remote_operation_info=self.remote_gates[gate_idx],
            )
        if name.startswith("teleportation_"):
            teleport_idx = int(name.split("_")[-1])
            return ExecutionStep(
                event_id=event_id,
                dependencies=deps,
                is_teleportation=True,
                remote_operation_info=self.teleportations[teleport_idx],
            )

        msg = f"Unknown event in execution graph: {name!r}."
        raise ValueError(msg)

    def get_execution_order(self) -> list[ExecutionStep]:
        """Return execution steps in topologically sorted order (Kahn's Algorithm).

        Events whose dependencies are all satisfied are processed first-in,
        first-out. Ties follow the insertion order of the execution graph.

        Returns:
            list[ExecutionStep]: One step per event, each placed after all of
            its dependencies.

        Raises:
            ValueError: If not every event can be ordered, i.e. the graph has
                a cycle or an event depends on an ID absent from the graph.
        """
        graph = self._execution_graph
        in_degree = {event_id: len(deps) for event_id, deps in graph.items()}

        # Reverse adjacency built once, so each completed event only visits the
        # events waiting on it instead of rescanning the whole graph. Lists keep
        # graph insertion order, which keeps the resulting order deterministic.
        dependents: dict[EventId, list[EventId]] = {event_id: [] for event_id in graph}
        for event_id, deps in graph.items():
            for dep in deps:
                if dep in dependents:
                    dependents[dep].append(event_id)

        ready = deque(event_id for event_id, degree in in_degree.items() if degree == 0)
        result: list[ExecutionStep] = []

        while ready:
            current_id = ready.popleft()
            result.append(self._make_step(current_id, graph[current_id]))
            for waiting_id in dependents[current_id]:
                in_degree[waiting_id] -= 1
                if in_degree[waiting_id] == 0:
                    ready.append(waiting_id)

        if len(result) != len(graph):
            msg = "Circular dependency detected."
            raise ValueError(msg)
        return result

    @property
    def partitions(self) -> dict[str, Partition]:
        """Dictionary mapping partition IDs to their corresponding Partition objects."""
        return self._partitions