Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 47 additions & 50 deletions cirq-google/cirq_google/devices/grid_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,31 @@
_SQRT_ISWAP_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])
_SQRT_ISWAP_INV_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])
_CZ_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.CZ])
_SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC)
_SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP)
_SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV)
_CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ)


# TODO(#5050) Add GlobalPhaseGate
# Target gates of `cirq_google.GoogleCZTargetGateset`.
_CZ_TARGET_GATES = [_CZ_FSIM_GATE_FAMILY, _PHASED_XZ_GATE_FAMILY, _MEASUREMENT_GATE_FAMILY]
_CZ_TARGET_GATES = [
_CZ_FSIM_GATE_FAMILY,
_CZ_GATE_FAMILY,
_PHASED_XZ_GATE_FAMILY,
_MEASUREMENT_GATE_FAMILY,
]
# Target gates of `cirq_google.SycamoreTargetGateset`.
_SYC_TARGET_GATES = [_SYC_FSIM_GATE_FAMILY, _PHASED_XZ_GATE_FAMILY, _MEASUREMENT_GATE_FAMILY]
_SYC_TARGET_GATES = [
_SYC_FSIM_GATE_FAMILY,
_SYC_GATE_FAMILY,
_PHASED_XZ_GATE_FAMILY,
_MEASUREMENT_GATE_FAMILY,
]
# Target gates of `cirq.SqrtIswapTargetGateset`
_SQRT_ISWAP_TARGET_GATES = [
_SQRT_ISWAP_FSIM_GATE_FAMILY,
_SQRT_ISWAP_GATE_FAMILY,
_PHASED_XZ_GATE_FAMILY,
_MEASUREMENT_GATE_FAMILY,
]
Expand All @@ -77,51 +92,44 @@ class _GateRepresentations:

Attributes:
gate_spec_name: The name of gate type in `GateSpecification`.
deserialized_forms: Gate representations to be included when the corresponding
`GateSpecification` gate type is deserialized into gatesets and gate durations.
serializable_forms: GateFamilies used to check whether a given gate can be serialized to the
gate type in this _GateRepresentation.
supported_gates: A list of gates that can be serialized into the `GateSpecification` with
the matching name.
"""

gate_spec_name: str
deserialized_forms: List[GateOrFamily]
serializable_forms: List[cirq.GateFamily]
supported_gates: List[cirq.GateFamily]


# Gates recognized by the GridDevice class. This controls the (de)serialization between
# `DeviceSpecification.valid_gates` and `cirq.Gateset`.

"""Valid gates for a GridDevice."""
# This is a superset of valid gates for a given `GridDevice` instance. The specific gateset depends
# on the underlying device.

# Edit this list to add support for new gates. If a new `_GateRepresentations` is added, add a new
# `GateSpecification` message in cirq-google/cirq_google/api/v2/device.proto.

# Update `_build_compilation_target_gatesets()` if the gate you are updating affects an existing
# CompilationTargetGateset there, or if you'd like to add another `CompilationTargetGateset` to
# allow users to transform their circuits that include your gate.
_GATES: List[_GateRepresentations] = [
_GateRepresentations(
gate_spec_name='syc',
deserialized_forms=[_SYC_FSIM_GATE_FAMILY],
serializable_forms=[_SYC_FSIM_GATE_FAMILY, cirq.GateFamily(ops.SYC)],
gate_spec_name='syc', supported_gates=[_SYC_FSIM_GATE_FAMILY, _SYC_GATE_FAMILY]
),
_GateRepresentations(
gate_spec_name='sqrt_iswap',
deserialized_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY],
serializable_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP)],
supported_gates=[_SQRT_ISWAP_FSIM_GATE_FAMILY, _SQRT_ISWAP_GATE_FAMILY],
),
_GateRepresentations(
gate_spec_name='sqrt_iswap_inv',
deserialized_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY],
serializable_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP_INV)],
supported_gates=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY, _SQRT_ISWAP_INV_GATE_FAMILY],
),
_GateRepresentations(
gate_spec_name='cz',
deserialized_forms=[_CZ_FSIM_GATE_FAMILY],
serializable_forms=[_CZ_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.CZ)],
gate_spec_name='cz', supported_gates=[_CZ_FSIM_GATE_FAMILY, _CZ_GATE_FAMILY]
),
_GateRepresentations(
gate_spec_name='phased_xz',
deserialized_forms=[
cirq.PhasedXZGate,
cirq.XPowGate,
cirq.YPowGate,
cirq.PhasedXPowGate,
cirq.HPowGate,
cirq.GateFamily(cirq.I),
cirq.ops.SingleQubitCliffordGate,
],
serializable_forms=[
supported_gates=[
# TODO: Extend support to cirq.IdentityGate.
cirq.GateFamily(cirq.I),
cirq.GateFamily(cirq.PhasedXZGate),
Expand All @@ -134,29 +142,20 @@ class _GateRepresentations:
),
_GateRepresentations(
gate_spec_name='virtual_zpow',
deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])],
serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])],
supported_gates=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])],
),
_GateRepresentations(
gate_spec_name='physical_zpow',
deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])],
serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])],
supported_gates=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])],
),
_GateRepresentations(
gate_spec_name='coupler_pulse',
deserialized_forms=[experimental_ops.CouplerPulse],
serializable_forms=[cirq.GateFamily(experimental_ops.CouplerPulse)],
),
_GateRepresentations(
gate_spec_name='meas',
deserialized_forms=[cirq.MeasurementGate],
serializable_forms=[cirq.GateFamily(cirq.MeasurementGate)],
supported_gates=[cirq.GateFamily(experimental_ops.CouplerPulse)],
),
_GateRepresentations(
gate_spec_name='wait',
deserialized_forms=[cirq.WaitGate],
serializable_forms=[cirq.GateFamily(cirq.WaitGate)],
gate_spec_name='meas', supported_gates=[cirq.GateFamily(cirq.MeasurementGate)]
),
_GateRepresentations(gate_spec_name='wait', supported_gates=[cirq.GateFamily(cirq.WaitGate)]),
]


Expand Down Expand Up @@ -216,7 +215,7 @@ def _serialize_gateset_and_gate_durations(
for gate_family in gateset.gates:
gate_spec = v2.device_pb2.GateSpecification()
gate_rep = next(
(gr for gr in _GATES for gf in gr.serializable_forms if gf == gate_family), None
(gr for gr in _GATES for gf in gr.supported_gates if gf == gate_family), None
)
if gate_rep is None:
raise ValueError(f'Unrecognized gate: {gate_family}.')
Expand All @@ -228,13 +227,13 @@ def _serialize_gateset_and_gate_durations(
# Set gate duration
gate_durations_picos = {
int(gate_durations[gf].total_picos())
for gf in gate_rep.serializable_forms
for gf in gate_rep.supported_gates
if gf in gate_durations
}
if len(gate_durations_picos) > 1:
raise ValueError(
'Multiple gate families in the following list exist in the gate duration dict, and '
f'they are expected to have the same duration value: {gate_rep.serializable_forms}'
f'they are expected to have the same duration value: {gate_rep.supported_gates}'
)
elif len(gate_durations_picos) == 1:
gate_spec.gate_duration_picos = gate_durations_picos.pop()
Expand Down Expand Up @@ -269,10 +268,8 @@ def _deserialize_gateset_and_gate_durations(
)
continue

gates_list.extend(gate_rep.deserialized_forms)
for g in gate_rep.deserialized_forms:
if not isinstance(g, cirq.GateFamily):
g = cirq.GateFamily(g)
gates_list.extend(gate_rep.supported_gates)
for g in gate_rep.supported_gates:
gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos)

# TODO(#5050) Add GlobalPhaseGate support
Expand Down
32 changes: 23 additions & 9 deletions cirq-google/cirq_google/devices/grid_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,26 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
cirq.ops.phased_x_z_gate.PhasedXZGate,
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.GateFamily(cirq_google.SYC),
cirq.GateFamily(cirq.SQRT_ISWAP),
cirq.GateFamily(cirq.SQRT_ISWAP_INV),
cirq.GateFamily(cirq.CZ),
cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate),
cirq.GateFamily(cirq.ops.common_gates.XPowGate),
cirq.GateFamily(cirq.ops.common_gates.YPowGate),
cirq.GateFamily(cirq.I),
cirq.ops.SingleQubitCliffordGate,
cirq.ops.HPowGate,
cirq.ops.phased_x_gate.PhasedXPowGate,
cirq.GateFamily(cirq.ops.SingleQubitCliffordGate),
cirq.GateFamily(cirq.ops.HPowGate),
cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate),
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
),
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
),
cirq_google.experimental.ops.coupler_pulse.CouplerPulse,
cirq.ops.measurement_gate.MeasurementGate,
cirq.ops.wait_gate.WaitGate,
cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse),
cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate),
cirq.GateFamily(cirq.ops.wait_gate.WaitGate),
)

base_duration = cirq.Duration(picos=1_000)
Expand All @@ -113,6 +117,10 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]): base_duration * 1,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]): base_duration * 2,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]): base_duration * 3,
cirq.GateFamily(cirq_google.SYC): base_duration * 0,
cirq.GateFamily(cirq.SQRT_ISWAP): base_duration * 1,
cirq.GateFamily(cirq.SQRT_ISWAP_INV): base_duration * 2,
cirq.GateFamily(cirq.CZ): base_duration * 3,
cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4,
Expand All @@ -139,6 +147,9 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq.GateFamily(cirq_google.SYC),
cirq.GateFamily(cirq.SQRT_ISWAP),
cirq.GateFamily(cirq.SQRT_ISWAP_INV),
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.ops.common_gates.HPowGate,
Expand All @@ -161,6 +172,9 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
cirq.GateFamily(cirq_google.SYC),
cirq.GateFamily(cirq.SQRT_ISWAP_INV),
cirq.GateFamily(cirq.CZ),
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.ops.common_gates.HPowGate,
Expand Down