Source code for kosmos.dqc_scheduling.execution_scheduler
from collections import deque
from dataclasses import dataclass

from qiskit import QuantumCircuit

from kosmos.dqc_scheduling.event import EventId
from kosmos.dqc_scheduling.remote_operation_detector import (
    RemoteOperationDetector,
    RemoteOperationInfo,
)
from kosmos.dqc_scheduling.space_time_matrix import SpaceTimeMatrix
from kosmos.partitioning.partition import Partition
[docs]
@dataclass
class ExecutionStep:
    """One entry of the execution order produced by the scheduler.

    A step stands for a partition, a remote CNOT or a teleportation. Only
    ``event_id`` and ``dependencies`` are required; the remaining fields fall
    back to their defaults when they do not apply to the step.

    Attributes:
        event_id (EventId): Identifier for the event.
        dependencies (set[EventId]): Event IDs that must complete before this step.
        is_remote_gate (bool): True when the step is a remote gate operation.
            Defaults to False.
        is_teleportation (bool): True when the step is a teleportation operation.
            Defaults to False.
        circuit (QuantumCircuit | None): Quantum circuit of the step, or None if
            not applicable. Defaults to None.
        remote_operation_info (RemoteOperationInfo | None): Details of the remote
            operation, or None if not applicable. Defaults to None.

    """

    # Required fields.
    event_id: EventId
    dependencies: set[EventId]

    # Kind flags; both stay False unless explicitly set.
    is_remote_gate: bool = False
    is_teleportation: bool = False

    # Optional payloads.
    circuit: QuantumCircuit | None = None
    remote_operation_info: RemoteOperationInfo | None = None
[docs]
class ExecutionScheduler:
    """Manages execution dependencies using space-time matrix.

    A dependency graph is built once at construction time from the partitions
    held by the matrix and from the remote gates and teleportations reported by
    ``RemoteOperationDetector``. ``get_execution_order`` returns a topological
    ordering of that graph.

    Attributes:
        space_time_matrix (SpaceTimeMatrix): Matrix storing assignment of qubits.
        remote_gates (list[RemoteOperationInfo]): List of detected remote gate operations.
        teleportations (list[RemoteOperationInfo]): List of detected qubit teleportations.

    """

    # Event-name prefixes shared by the graph builder and the step factory so
    # the two cannot drift apart.
    _REMOTE_GATE_PREFIX = "remote_gate_"
    _TELEPORTATION_PREFIX = "teleportation_"

    def __init__(self, matrix: SpaceTimeMatrix) -> None:
        """Initialize the scheduler and build its dependency graph.

        Args:
            matrix (SpaceTimeMatrix): Matrix storing assignment of qubits.

        """
        self.space_time_matrix: SpaceTimeMatrix = matrix
        self._partitions: dict[str, Partition] = matrix.partition_info

        detector = RemoteOperationDetector(matrix)
        self.remote_gates: list[RemoteOperationInfo] = detector.detect_remote_gates()
        self.teleportations: list[RemoteOperationInfo] = detector.detect_teleportations()

        # Must run last: the graph is derived from the attributes set above.
        self._execution_graph: dict[EventId, set[EventId]] = self._build_execution_graph()

    def _build_execution_graph(self) -> dict[EventId, set[EventId]]:
        """Build execution graph.

        Every partition, remote gate and teleportation becomes one graph entry
        mapping its event ID to the set of event IDs it depends on.

        Returns:
            dict[EventId, set[EventId]]: Graph containing dependencies to determine execution
                order.

        """
        dag: dict[EventId, set[EventId]] = {}

        # A partition depends on every other partition that shares a physical
        # qubit with it and has finished by the time it starts.
        for partition_id, partition in self._partitions.items():
            qubits = partition.physical_qubits_used
            dag[EventId(partition_id)] = {
                EventId(other_id)
                for other_id, other_partition in self._partitions.items()
                if other_id != partition_id
                and other_partition.end_time is not None
                and any(q in other_partition.physical_qubits_used for q in qubits)
                and other_partition.end_time <= partition.start_time
            }

        # Remote gates have no dependencies of their own; partitions touching
        # the control or target qubit that start after the gate depend on it.
        # NOTE(review): the comparison here is strict (>) while teleportations
        # below use >=; preserved as found -- confirm this is intended.
        for i, remote_gate in enumerate(self.remote_gates):
            remote_gate_event_id = EventId(f"{self._REMOTE_GATE_PREFIX}{i}")
            dag[remote_gate_event_id] = set()
            affected_qubits = {remote_gate.control_qubit, remote_gate.target_qubit}
            for partition_id, partition in self._partitions.items():
                if (
                    any(q in partition.physical_qubits_used for q in affected_qubits)
                    and partition.start_time > remote_gate.gate_time
                ):
                    dag[EventId(partition_id)].add(remote_gate_event_id)

        # Teleportations likewise have no dependencies; a partition depends on
        # one when it uses the teleported qubit on the teleportation's target
        # node and does not start before the teleportation.
        for i, teleportation in enumerate(self.teleportations):
            teleport_event_id = EventId(f"{self._TELEPORTATION_PREFIX}{i}")
            dag[teleport_event_id] = set()
            for partition_id, partition in self._partitions.items():
                partition_node = partition.get_node_for_qubit(teleportation.qubit)
                if (
                    teleportation.qubit in partition.physical_qubits_used
                    and partition_node is not None
                    and partition_node.id == teleportation.target_node.id
                    and partition.start_time >= teleportation.gate_time
                ):
                    dag[EventId(partition_id)].add(teleport_event_id)

        return dag

    def _make_step(self, event_id: EventId) -> ExecutionStep:
        """Create the execution step describing a single graph event.

        Args:
            event_id (EventId): Event to convert; must be a key of the execution graph.

        Returns:
            ExecutionStep: Step carrying the partition circuit or the remote operation info.

        Raises:
            ValueError: If the event is neither a known partition, a remote gate
                nor a teleportation.

        """
        name = event_id.value
        # Copy so callers cannot mutate the scheduler's graph through a step.
        dependencies = set(self._execution_graph[event_id])

        # Partitions are recognised by membership rather than by a name prefix,
        # so partition IDs are not required to follow a naming convention.
        if name in self._partitions:
            return ExecutionStep(
                event_id=event_id,
                dependencies=dependencies,
                circuit=self._partitions[name].circuit,
            )
        if name.startswith(self._REMOTE_GATE_PREFIX):
            gate_idx = int(name.removeprefix(self._REMOTE_GATE_PREFIX))
            return ExecutionStep(
                event_id=event_id,
                dependencies=dependencies,
                is_remote_gate=True,
                remote_operation_info=self.remote_gates[gate_idx],
            )
        if name.startswith(self._TELEPORTATION_PREFIX):
            teleport_idx = int(name.removeprefix(self._TELEPORTATION_PREFIX))
            return ExecutionStep(
                event_id=event_id,
                dependencies=dependencies,
                is_teleportation=True,
                remote_operation_info=self.teleportations[teleport_idx],
            )

        msg = f"Unknown event type for event '{name}'."
        raise ValueError(msg)

    def get_execution_order(self) -> list[ExecutionStep]:
        """Return execution steps in topologically sorted order (Kahn's Algorithm).

        Returns:
            list[ExecutionStep]: Steps ordered so that each one appears after all
                of its dependencies.

        Raises:
            ValueError: If the graph contains a circular dependency, or an event
                of unknown type.

        """
        graph = self._execution_graph
        in_degree = {eid: len(deps) for eid, deps in graph.items()}

        # Reverse edges (dependency -> dependents), built once so each processed
        # event only touches its own dependents. Dependents are listed in graph
        # insertion order, which keeps the resulting order deterministic.
        dependents: dict[EventId, list[EventId]] = {}
        for eid, deps in graph.items():
            for dep in deps:
                dependents.setdefault(dep, []).append(eid)

        ready = deque(eid for eid, degree in in_degree.items() if degree == 0)
        result: list[ExecutionStep] = []
        while ready:
            current_id = ready.popleft()
            result.append(self._make_step(current_id))
            for dependent_id in dependents.get(current_id, ()):
                in_degree[dependent_id] -= 1
                if in_degree[dependent_id] == 0:
                    ready.append(dependent_id)

        # Events left with a non-zero in-degree were never released.
        if len(result) != len(graph):
            msg = "Circular dependency detected."
            raise ValueError(msg)
        return result

    @property
    def partitions(self) -> dict[str, Partition]:
        """Dictionary mapping partition IDs to their corresponding Partition objects."""
        return self._partitions