Source code for pytket.passes.combinecondpass

# Copyright Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, cast

from pytket import unit_id
from pytket.circuit import BarrierOp, CircBox, Circuit, Command, Conditional

from .._tket.passes import BasePass, CustomPass


def _extract_cond(cmd: Command) -> tuple[int, list[Any]] | None:
    """Return the predicate of a conditional command that may be merged.

    The predicate is the pair ``(value, condition_bits)``, where
    ``condition_bits`` are the leading ``width`` arguments of the command.
    ``None`` is returned for unconditional commands and for conditionals whose
    inner op is a ``CircBox``; neither is a candidate for merging.
    """
    op = cmd.op
    if not isinstance(op, Conditional) or isinstance(op.op, CircBox):
        return None
    return (op.value, cmd.args[: op.width])


def _append_cmd(circ: Circuit, cmd: Command) -> None:
    """Append ``cmd`` to ``circ`` with any ``Conditional`` wrapper stripped.

    For a conditional command, only the wrapped op is appended, and the leading
    condition-bit arguments are dropped. The caller is responsible for
    re-applying the predicate, e.g. via an enclosing conditional box.
    Barriers are re-added with their data string. Other ops keep their opgroup
    when the command has one.
    """
    op = cmd.op
    n_cond_bits = 0
    if isinstance(op, Conditional):
        # unwrap: the condition bits precede the op's own operands in cmd.args
        n_cond_bits = op.width
        op = op.op
    operands = cmd.args[n_cond_bits:]

    if isinstance(op, BarrierOp):
        circ.add_barrier(operands, op.data)
        return
    if cmd.opgroup is None:
        circ.add_gate(op, operands)
    else:
        circ.add_gate(op, operands, opgroup=cmd.opgroup)


def _emit_cond_box(
    top_circ: Circuit,
    sub_circ: Circuit,
    cond: tuple[int, list[Any]],
    max_wreg: int,
    max_rreg: int,
) -> None:
    """Append the accumulated subsequence ``sub_circ`` to ``top_circ`` under ``cond``.

    :param top_circ: the circuit being built; modified in place.
    :param sub_circ: the ops sharing the predicate ``cond``, stored without their
        ``Conditional`` wrappers. WASM/RNG registers are added to it here.
    :param cond: ``(value, condition_bits)`` as produced by ``_extract_cond``.
    :param max_wreg: largest WASM state index used in ``sub_circ``, or -1 if none.
    :param max_rreg: largest RNG state index used in ``sub_circ``, or -1 if none.

    A single op is re-added directly as a conditional gate and keeps its opgroup.
    Two or more ops are wrapped in one conditional ``CircBox``. The box itself
    carries no opgroup, but the gates inside it keep theirs.
    """
    # add WASM and RNG args
    if max_wreg > -1:
        sub_circ._add_w_register(max_wreg + 1)  # noqa: SLF001
    if max_rreg > -1:
        sub_circ._add_r_register(max_rreg + 1)  # noqa: SLF001

    cond_value, cond_args = cond
    sub_cmds = sub_circ.get_commands()
    if len(sub_cmds) == 1:
        # only one predicated op: emit it directly rather than inside a CircBox.
        # With no parent box, dropping its opgroup here would lose it entirely.
        sub_cmd = sub_cmds[0]
        extra: dict[str, Any] = {}
        if sub_cmd.opgroup is not None:
            extra["opgroup"] = sub_cmd.opgroup
        top_circ.add_gate(
            sub_cmd.op,
            sub_cmd.args,
            condition_bits=cond_args,
            condition_value=cond_value,
            **extra,
        )
    else:
        sub_arg_list = sub_circ.qubits + sub_circ.bits
        top_circ.add_gate(
            CircBox(sub_circ),
            sub_arg_list,
            condition_bits=cond_args,
            condition_value=cond_value,
        )


def _state_reg_index(arg: Any) -> int:
    """Return the register index of a WASM/RNG state unit.

    The index is parsed from the unit's string form: the text between the first
    ``[`` and the following ``]``.
    """
    return int(str(arg).split("[")[1].split("]")[0])


def _append_passthrough(circ: Circuit, cmd: Command) -> None:
    """Append a command that is not merged into a conditional box, unchanged.

    Unconditional commands go through ``_append_cmd``. Conditionals that
    ``_extract_cond`` declines to merge (those wrapping a ``CircBox``) are
    re-added with their condition bits and value intact. Routing them through
    ``_append_cmd`` would unwrap them and make the box unconditional.
    """
    op = cmd.op
    if not isinstance(op, Conditional):
        _append_cmd(circ, cmd)
        return
    width = op.width
    extra: dict[str, Any] = {}
    if cmd.opgroup is not None:
        extra["opgroup"] = cmd.opgroup
    circ.add_gate(
        op.op,
        cmd.args[width:],
        condition_bits=cmd.args[:width],
        condition_value=op.value,
        **extra,
    )


def _combine_conditionals(circuit: Circuit) -> Circuit:  # noqa: PLR0912, PLR0915
    """Walk the sequence of commands in the circuit and combine contiguous subsequences
    of conditionals with the same predicate into conditional boxes.

    Conditionals whose inner op is already a ``CircBox`` are not merged. They are
    copied through with their predicate preserved. The pass currently does not
    propagate opgroup names to the parent Boxes, but the group names should still
    be present on the gates within the box."""

    # the output circuit
    new_circuit = Circuit(0, circuit.name)
    new_circuit.add_phase(circuit.phase)
    # wasm_uid should get set automatically as we add WASM ops
    for qb in circuit.qubits:
        new_circuit.add_qubit(qb)
    for cb in circuit.bits:
        new_circuit.add_bit(cb)

    # the tuple of value and args describing the current conditional
    curr_cond: tuple[int, list[Any]] | None = None
    # subcircuit for the current subsequence
    sub_circ = Circuit()
    # arg set for the current subsequence
    sub_args: set[Any] = set()
    # largest WASM/RNG ID seen in the total circuit/current subsequence
    max_wreg = -1
    max_rreg = -1
    max_sub_wreg = -1
    max_sub_rreg = -1
    # true if we need to emit the subcircuit before proceeding
    # due to a condition bit being used as an operand
    break_dep = False

    for cmd in circuit.get_commands():
        cond = _extract_cond(cmd)
        # if this is not part of the ongoing subsequence or we need to emit due to a
        # possible write to the predicate, emit the ongoing subsequence first
        if curr_cond is not None and (break_dep or curr_cond != cond):
            _emit_cond_box(new_circuit, sub_circ, curr_cond, max_sub_wreg, max_sub_rreg)

            sub_circ = Circuit()
            sub_args.clear()
            max_sub_wreg = -1
            max_sub_rreg = -1
            curr_cond = None
            break_dep = False

        # anything that is not a mergeable conditional is emitted directly
        if cond is None:
            _append_passthrough(new_circuit, cmd)
            continue

        width = cast("Conditional", cmd.op).width
        for arg in cmd.args[width:]:
            # this is overly conservative, because it will unnecessarily
            # break up reads of the predicate value. to do better we need
            # to distinguish the op's read and write operands somehow
            break_dep = break_dep or arg in cond[1]
            if arg in sub_args:
                continue
            if isinstance(arg, unit_id.Bit):
                sub_circ.add_bit(arg)
            elif isinstance(arg, unit_id.Qubit):
                sub_circ.add_qubit(arg)
            elif isinstance(arg, unit_id.WasmState):
                reg_id = _state_reg_index(arg)
                max_wreg = max(max_wreg, reg_id)
                max_sub_wreg = max(max_sub_wreg, reg_id)
            elif isinstance(arg, unit_id.RngState):
                reg_id = _state_reg_index(arg)
                max_rreg = max(max_rreg, reg_id)
                max_sub_rreg = max(max_sub_rreg, reg_id)
            else:
                raise ValueError("Unknown arg type")
            sub_args.add(arg)

        _append_cmd(sub_circ, cmd)
        curr_cond = cond

    # emit final if necessary
    if curr_cond is not None:
        _emit_cond_box(new_circuit, sub_circ, curr_cond, max_sub_wreg, max_sub_rreg)

    # add WASM and RNG states if necessary
    if max_wreg > -1:
        new_circuit._add_w_register(max_wreg + 1)  # noqa: SLF001
    if max_rreg > -1:
        new_circuit._add_r_register(max_rreg + 1)  # noqa: SLF001

    return new_circuit


def CombineCondPass() -> BasePass:
    """Create a pass which combines contiguous groups of conditional gates with the
    same predicate into conditional boxes.

    :return: a ``CustomPass`` that applies ``_combine_conditionals``, labelled
        ``"combine_conditionals"``.
    """
    return CustomPass(_combine_conditionals, label="combine_conditionals")