Skip to content

Commit f6187b3

Browse files
authored
Better stimcirq serialization (#7192)
* Better stimcirq serialization - After some testing, the serialization for stimcirq was a bit subpar. Serializing the json was rather bulky. - This splits the stimcirq serialization and deserialization into its own file, so its more encapsulated. - This also serializes each operation directly, rather than jamming the whole json into the gate. - For a typical QEC surface code circuit, this reduces the size of the proto by about 25%. * Fix coverage and typing. * try again for typecheck * Fix merge issues. * sort imports * Add serializer unit tests.
1 parent b16aeda commit f6187b3

File tree

6 files changed

+342
-43
lines changed

6 files changed

+342
-43
lines changed

cirq-google/cirq_google/serialization/circuit_serializer.py

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
op_deserializer,
3939
op_serializer,
4040
serializer,
41+
stimcirq_deserializer,
42+
stimcirq_serializer,
4143
tag_deserializer,
4244
tag_serializer,
4345
)
@@ -47,9 +49,6 @@
4749
# CircuitSerializer is the dedicated serializer for the v2.5 format.
4850
_SERIALIZER_NAME = 'v2_5'
4951

50-
# Package name for stimcirq
51-
_STIMCIRQ_MODULE = "stimcirq"
52-
5352

5453
class CircuitSerializer(serializer.Serializer):
5554
"""A class for serializing and deserializing programs and operations.
@@ -93,6 +92,8 @@ def __init__(
9392
self.op_deserializer = op_deserializer
9493
self.tag_serializer = tag_serializer
9594
self.tag_deserializer = tag_deserializer
95+
self.stimcirq_serializer = stimcirq_serializer.StimCirqSerializer()
96+
self.stimcirq_deserializer = stimcirq_deserializer.StimCirqDeserializer()
9697

9798
def serialize(
9899
self, program: cirq.AbstractCircuit, msg: Optional[v2.program_pb2.Program] = None
@@ -160,6 +161,10 @@ def _serialize_circuit(
160161
self.op_serializer.to_proto(
161162
op, op_pb, constants=constants, raw_constants=raw_constants
162163
)
164+
elif self.stimcirq_serializer.can_serialize_operation(op):
165+
self.stimcirq_serializer.to_proto(
166+
op, op_pb, constants=constants, raw_constants=raw_constants
167+
)
163168
else:
164169
self._serialize_gate_op(
165170
op, op_pb, constants=constants, raw_constants=raw_constants
@@ -174,6 +179,10 @@ def _serialize_circuit(
174179
self.op_serializer.to_proto(
175180
op, op_pb, constants=constants, raw_constants=raw_constants
176181
)
182+
elif self.stimcirq_serializer.can_serialize_operation(op):
183+
self.stimcirq_serializer.to_proto(
184+
op, op_pb, constants=constants, raw_constants=raw_constants
185+
)
177186
else:
178187
self._serialize_gate_op(
179188
op, op_pb, constants=constants, raw_constants=raw_constants
@@ -277,30 +286,6 @@ def _serialize_gate_op(
277286
arg_func_langs.float_arg_to_proto(
278287
gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz
279288
)
280-
elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr(
281-
gate, "__module__", ""
282-
).startswith(_STIMCIRQ_MODULE):
283-
# Special handling for stimcirq objects, which can be both operations and gates.
284-
stimcirq_obj = (
285-
op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate
286-
)
287-
if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'):
288-
# All stimcirq gates currently have _json_dict_defined
289-
msg.internalgate.name = type(stimcirq_obj).__name__
290-
msg.internalgate.module = _STIMCIRQ_MODULE
291-
if isinstance(stimcirq_obj, cirq.Gate):
292-
msg.internalgate.num_qubits = stimcirq_obj.num_qubits()
293-
else:
294-
msg.internalgate.num_qubits = len(stimcirq_obj.qubits)
295-
296-
# Store json_dict objects in gate_args
297-
for k, v in stimcirq_obj._json_dict_().items():
298-
arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k])
299-
else:
300-
# New stimcirq op without a json dict has been introduced
301-
raise ValueError(
302-
f'Cannot serialize stimcirq {op!r}:{type(gate)}'
303-
) # pragma: no cover
304289
else:
305290
raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')
306291

@@ -438,6 +423,12 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
438423
constants=proto.constants,
439424
deserialized_constants=deserialized_constants,
440425
)
426+
elif self.stimcirq_deserializer.can_deserialize_proto(constant.operation_value):
427+
op_pb = self.stimcirq_deserializer.from_proto(
428+
constant.operation_value,
429+
constants=proto.constants,
430+
deserialized_constants=deserialized_constants,
431+
)
441432
else:
442433
op_pb = self._deserialize_gate_op(
443434
constant.operation_value,
@@ -517,6 +508,10 @@ def _deserialize_moment(
517508
gate_op = self.op_deserializer.from_proto(
518509
op, constants=constants, deserialized_constants=deserialized_constants
519510
)
511+
elif self.stimcirq_deserializer.can_deserialize_proto(op):
512+
gate_op = self.stimcirq_deserializer.from_proto(
513+
op, constants=constants, deserialized_constants=deserialized_constants
514+
)
520515
else:
521516
gate_op = self._deserialize_gate_op(
522517
op, constants=constants, deserialized_constants=deserialized_constants
@@ -718,20 +713,7 @@ def _deserialize_gate_op(
718713
op = cirq.ResetChannel(dimension=dimensions)(*qubits)
719714
elif which_gate_type == 'internalgate':
720715
msg = operation_proto.internalgate
721-
if msg.module == _STIMCIRQ_MODULE and msg.name in _stimcirq_json_resolvers():
722-
# special handling for stimcirq
723-
# Use JSON resolver to instantiate the object
724-
kwargs = {}
725-
for k, v in msg.gate_args.items():
726-
arg = arg_func_langs.arg_from_proto(v)
727-
if arg is not None:
728-
kwargs[k] = arg
729-
op = _stimcirq_json_resolvers()[msg.name](**kwargs)
730-
if qubits:
731-
op = op(*qubits)
732-
else:
733-
# all other internal gates
734-
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
716+
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
735717
elif which_gate_type == 'couplerpulsegate':
736718
gate = CouplerPulse(
737719
hold_time=cirq.Duration(

cirq-google/cirq_google/serialization/circuit_serializer_test.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,12 +1162,41 @@ def test_reset_gate_with_no_dimension():
11621162
assert reset_circuit == cirq.Circuit(cirq.R(cirq.q(1, 2)))
11631163

11641164

1165-
def test_stimcirq_gates():
1165+
@pytest.mark.parametrize('use_constants_table', [True, False])
1166+
def test_stimcirq_gates(use_constants_table: bool):
11661167
stimcirq = pytest.importorskip("stimcirq")
1167-
serializer = cg.CircuitSerializer()
1168+
serializer = cg.CircuitSerializer(
1169+
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
1170+
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
1171+
)
11681172
q = cirq.q(1, 2)
11691173
q2 = cirq.q(2, 2)
11701174
c = cirq.Circuit(
1175+
cirq.Moment(
1176+
stimcirq.CumulativeObservableAnnotation(parity_keys=["m"], observable_index=123)
1177+
),
1178+
cirq.Moment(
1179+
stimcirq.MeasureAndOrResetGate(
1180+
measure=True,
1181+
reset=False,
1182+
basis='Z',
1183+
invert_measure=True,
1184+
key='mmm',
1185+
measure_flip_probability=0.125,
1186+
)(q2)
1187+
),
1188+
cirq.Moment(stimcirq.ShiftCoordsAnnotation([1.0, 2.0])),
1189+
cirq.Moment(
1190+
stimcirq.SweepPauli(stim_sweep_bit_index=2, cirq_sweep_symbol='t', pauli=cirq.X)(q)
1191+
),
1192+
cirq.Moment(
1193+
stimcirq.SweepPauli(stim_sweep_bit_index=3, cirq_sweep_symbol='y', pauli=cirq.Y)(q)
1194+
),
1195+
cirq.Moment(
1196+
stimcirq.SweepPauli(stim_sweep_bit_index=4, cirq_sweep_symbol='t', pauli=cirq.Z)(q)
1197+
),
1198+
cirq.Moment(stimcirq.TwoQubitAsymmetricDepolarizingChannel([0.05] * 15)(q, q2)),
1199+
cirq.Moment(stimcirq.CZSwapGate()(q, q2)),
11711200
cirq.Moment(stimcirq.CXSwapGate(inverted=True)(q, q2)),
11721201
cirq.Moment(cirq.measure(q, key="m")),
11731202
cirq.Moment(stimcirq.DetAnnotation(parity_keys=["m"])),
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2025 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import functools
16+
from typing import Any, Dict, List
17+
18+
import cirq
19+
from cirq_google.api import v2
20+
from cirq_google.serialization import arg_func_langs
21+
from cirq_google.serialization.op_deserializer import OpDeserializer
22+
23+
24+
@functools.cache
25+
def _stimcirq_json_resolvers():
26+
"""Retrieves stimcirq JSON resolvers if stimcirq is installed.
27+
Returns an empty dict if not installed."""
28+
try:
29+
import stimcirq
30+
31+
return stimcirq.JSON_RESOLVERS_DICT
32+
except ModuleNotFoundError: # pragma: no cover
33+
return {} # pragma: no cover
34+
35+
36+
class StimCirqDeserializer(OpDeserializer):
37+
"""Describes how to serialize CircuitOperations."""
38+
39+
def can_deserialize_proto(self, proto: v2.program_pb2.Operation):
40+
return (
41+
proto.WhichOneof('gate_value') == 'internalgate'
42+
and proto.internalgate.module == 'stimcirq'
43+
)
44+
45+
def from_proto(
46+
self,
47+
proto: v2.program_pb2.Operation,
48+
*,
49+
constants: List[v2.program_pb2.Constant],
50+
deserialized_constants: List[Any],
51+
) -> cirq.Operation:
52+
"""Turns a cirq_google Operation proto into a stimcirq object.
53+
54+
Args:
55+
proto: The proto object to be deserialized.
56+
constants: The list of Constant protos referenced by constant
57+
table indices in `proto`. This list should already have been
58+
parsed to produce 'deserialized_constants'.
59+
deserialized_constants: The deserialized contents of `constants`.
60+
61+
Returns:
62+
The deserialized stimcirq object
63+
64+
Raises:
65+
ValueError: If stimcirq is not installed or the object is not recognized.
66+
"""
67+
resolvers = _stimcirq_json_resolvers()
68+
cls_name = proto.internalgate.name
69+
70+
if cls_name not in resolvers:
71+
raise ValueError(f"stimcirq object {proto} not recognized. (Is stimcirq installed?)")
72+
73+
# Resolve each of the serialized arguments
74+
kwargs: Dict[str, Any] = {}
75+
for k, v in proto.internalgate.gate_args.items():
76+
if k == "pauli":
77+
# Special Handling for pauli gate
78+
pauli = v.arg_value.string_value
79+
if pauli == "X":
80+
kwargs[k] = cirq.X
81+
elif pauli == "Y":
82+
kwargs[k] = cirq.Y
83+
elif pauli == "Z":
84+
kwargs[k] = cirq.Z
85+
else:
86+
raise ValueError(f"Unknown stimcirq pauli Gate {v}")
87+
continue
88+
89+
arg = arg_func_langs.arg_from_proto(v)
90+
if arg is not None:
91+
kwargs[k] = arg
92+
93+
# Instantiate the class from the stimcirq resolvers
94+
op = resolvers[cls_name](**kwargs)
95+
96+
# If this operation has qubits, add them
97+
qubits = [deserialized_constants[q] for q in proto.qubit_constant_index]
98+
if qubits:
99+
op = op(*qubits)
100+
101+
return op
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from cirq_google.api import v2
18+
from cirq_google.serialization.stimcirq_deserializer import StimCirqDeserializer
19+
20+
21+
def test_bad_stimcirq_op():
22+
proto = v2.program_pb2.Operation()
23+
proto.internalgate.module = 'stimcirq'
24+
proto.internalgate.name = 'WolfgangPauli'
25+
26+
with pytest.raises(ValueError, match='not recognized'):
27+
_ = StimCirqDeserializer().from_proto(proto, constants=[], deserialized_constants=[])
28+
29+
30+
def test_bad_pauli_gate():
31+
proto = v2.program_pb2.Operation()
32+
proto.internalgate.module = 'stimcirq'
33+
proto.internalgate.name = 'SweepPauli'
34+
proto.internalgate.gate_args['stim_sweep_bit_index'].arg_value.float_value = 1.0
35+
proto.internalgate.gate_args['cirq_sweep_symbol'].arg_value.string_value = 't'
36+
proto.internalgate.gate_args['pauli'].arg_value.string_value = 'Q'
37+
38+
with pytest.raises(ValueError, match='pauli'):
39+
_ = StimCirqDeserializer().from_proto(proto, constants=[], deserialized_constants=[])

0 commit comments

Comments
 (0)