Skip to content

Commit 460dda0

Browse files
authored
Standardize serialization of CalibrationTag (#7480)
- Serialize CalibrationTag as a tag rather than use a custom token index field. - This tag was not converted over to the new Tag message since I didn't think anyone was still using it.
1 parent ffd78c8 commit 460dda0

File tree

7 files changed

+246
-191
lines changed

7 files changed

+246
-191
lines changed

cirq-google/cirq_google/api/v2/program.proto

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ message Circuit {
8181
// repeated moments.
8282
repeated int32 moment_indices = 3;
8383

84-
// Token that can be used to specify a version of a gate.
85-
// For instance, a gate that has been calibrated for a circuit.
86-
optional int32 token_constant_index = 4;
84+
// Deprecated field, do not use.
85+
reserved 4;
86+
8787
// Indices in the constant table for tags associated with the circuit
8888
repeated int32 tag_indices = 5;
8989
}
@@ -289,8 +289,8 @@ message Operation {
289289
// The token can be specified as a string or as a reference to
290290
// the constant table of the circuit.
291291
oneof token {
292-
string token_value = 4;
293-
int32 token_constant_index = 5;
292+
string token_value = 4 [deprecated = true];
293+
int32 token_constant_index = 5 [deprecated = true];
294294
}
295295

296296
// To be deprecated
@@ -342,6 +342,9 @@ message Tag {
342342
// Uses parameter model to interpolate FSim gate.
343343
FSimViaModelTag fsim_via_model = 7;
344344

345+
// Calibration Tag
346+
CalibrationTag calibration_tag = 9;
347+
345348
// Catch-all for all gates that do not fit into the
346349
// above tags.
347350
InternalTag internal_tag = 8;
@@ -380,6 +383,12 @@ message NoSyncTag {
380383
}
381384
}
382385

386+
// Tag to specify specific override tokens for operations or circuits.
387+
message CalibrationTag {
388+
// Token to serialize
389+
string token = 1;
390+
}
391+
383392
// Tag to represent any internal tags or tags not yet
384393
// implemented in the proto.
385394
message InternalTag {

cirq-google/cirq_google/api/v2/program_pb2.py

Lines changed: 138 additions & 132 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/api/v2/program_pb2.pyi

Lines changed: 28 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/ops/calibration_tag.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any
1818

1919
import cirq
20+
from cirq_google.api.v2 import program_pb2
2021

2122

2223
class CalibrationTag:
@@ -48,3 +49,15 @@ def __eq__(self, other) -> bool:
4849

4950
def __hash__(self) -> int:
5051
return hash(self.token)
52+
53+
def to_proto(self, msg: program_pb2.Tag | None = None) -> program_pb2.Tag:
54+
if msg is None:
55+
msg = program_pb2.Tag()
56+
msg.calibration_tag.token = self.token
57+
return msg
58+
59+
@staticmethod
60+
def from_proto(msg: program_pb2.Tag) -> CalibrationTag:
61+
if msg.WhichOneof("tag") != "calibration_tag":
62+
raise ValueError(f"Message is not a CalibrationTag, {msg}")
63+
return CalibrationTag(token=msg.calibration_tag.token)

cirq-google/cirq_google/ops/calibration_tag_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414

1515
from __future__ import annotations
1616

17+
import pytest
18+
1719
import cirq
1820
import cirq_google
21+
from cirq_google.api.v2 import program_pb2
1922

2023

2124
def test_equality():
@@ -39,3 +42,14 @@ def test_str_repr():
3942
assert str(example_tag) == 'CalibrationTag(\'foo\')'
4043
assert repr(example_tag) == 'cirq_google.CalibrationTag(\'foo\')'
4144
cirq.testing.assert_equivalent_repr(example_tag, setup_code=('import cirq\nimport cirq_google'))
45+
46+
47+
def test_proto_serialization():
48+
tag = cirq_google.CalibrationTag('foo')
49+
msg = tag.to_proto()
50+
assert tag == cirq_google.CalibrationTag.from_proto(msg)
51+
52+
with pytest.raises(ValueError, match="Message is not a CalibrationTag"):
53+
msg = program_pb2.Tag()
54+
msg.fsim_via_model.SetInParent()
55+
cirq_google.CalibrationTag.from_proto(msg)

cirq-google/cirq_google/serialization/circuit_serializer.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -320,39 +320,22 @@ def _serialize_tag(
320320
raw_constants: dict[Any, int],
321321
):
322322
constant = v2.program_pb2.Constant()
323-
if isinstance(tag, CalibrationTag):
324-
constant.string_value = tag.token
325-
if tag.token in raw_constants:
326-
msg.token_constant_index = raw_constants[tag.token]
323+
if (tag_index := raw_constants.get(tag, None)) is None:
324+
if self.tag_serializer and self.tag_serializer.can_serialize_tag(tag):
325+
self.tag_serializer.to_proto(
326+
tag, msg=constant.tag_value, constants=constants, raw_constants=raw_constants
327+
)
328+
elif getattr(tag, 'to_proto', None) is not None:
329+
tag.to_proto(constant.tag_value) # type: ignore
327330
else:
328-
# Token not found, add it to the list
329-
msg.token_constant_index = len(constants)
331+
warnings.warn(f'Unrecognized Tag {tag}, not serializing.')
332+
if constant.WhichOneof('const_value'):
330333
constants.append(constant)
331334
if raw_constants is not None:
332-
raw_constants[tag.token] = msg.token_constant_index
335+
raw_constants[tag] = len(constants) - 1
336+
msg.tag_indices.append(len(constants) - 1)
333337
else:
334-
if isinstance(tag, DynamicalDecouplingTag):
335-
# TODO(dstrain): Remove this once we are deserializing tag indices everywhere.
336-
tag.to_proto(msg=msg.tags.add())
337-
if (tag_index := raw_constants.get(tag, None)) is None:
338-
if self.tag_serializer and self.tag_serializer.can_serialize_tag(tag):
339-
self.tag_serializer.to_proto(
340-
tag,
341-
msg=constant.tag_value,
342-
constants=constants,
343-
raw_constants=raw_constants,
344-
)
345-
elif getattr(tag, 'to_proto', None) is not None:
346-
tag.to_proto(constant.tag_value) # type: ignore
347-
else:
348-
warnings.warn(f'Unrecognized Tag {tag}, not serializing.')
349-
if constant.WhichOneof('const_value'):
350-
constants.append(constant)
351-
if raw_constants is not None:
352-
raw_constants[tag] = len(constants) - 1
353-
msg.tag_indices.append(len(constants) - 1)
354-
else:
355-
msg.tag_indices.append(tag_index)
338+
msg.tag_indices.append(tag_index)
356339

357340
def _serialize_circuit_op(
358341
self,
@@ -507,8 +490,6 @@ def _deserialize_circuit(
507490
for moment_index in circuit_proto.moment_indices:
508491
moments.append(deserialized_constants[moment_index])
509492

510-
if circuit_proto.HasField('token_constant_index'):
511-
tags.append(CalibrationTag(constants[circuit_proto.token_constant_index].string_value))
512493
for tag_index in circuit_proto.tag_indices:
513494
tags.append(deserialized_constants[tag_index])
514495
return cirq.Circuit(moments, tags=tags)
@@ -879,6 +860,8 @@ def _deserialize_tag(self, msg: v2.program_pb2.Tag):
879860
return PhysicalZTag()
880861
elif which == 'fsim_via_model':
881862
return FSimViaModelTag()
863+
elif which == 'calibration_tag':
864+
return CalibrationTag.from_proto(msg)
882865
elif which == 'internal_tag':
883866
return InternalTag.from_proto(msg)
884867
else:

cirq-google/cirq_google/serialization/circuit_serializer_test.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -582,18 +582,18 @@ def test_serialize_deserialize_circuit_with_tokens():
582582
op_q0_tag1 = v2.program_pb2.Operation()
583583
op_q0_tag1.xpowgate.exponent.float_value = 1.0
584584
op_q0_tag1.qubit_constant_index.append(0)
585-
op_q0_tag1.token_constant_index = 1
585+
op_q0_tag1.tag_indices.append(1)
586586

587587
op_q1_tag2 = v2.program_pb2.Operation()
588588
op_q1_tag2.xpowgate.exponent.float_value = 1.0
589589
op_q1_tag2.qubit_constant_index.append(3)
590-
op_q1_tag2.token_constant_index = 4
590+
op_q1_tag2.tag_indices.append(4)
591591

592592
# Test repeated tag uses existing constant entey
593593
op_q0_tag2 = v2.program_pb2.Operation()
594594
op_q0_tag2.xpowgate.exponent.float_value = 1.0
595595
op_q0_tag2.qubit_constant_index.append(0)
596-
op_q0_tag2.token_constant_index = 4
596+
op_q0_tag2.tag_indices.append(4)
597597

598598
proto = v2.program_pb2.Program(
599599
language=v2.program_pb2.Language(arg_function_language='exp', gate_set=_SERIALIZER_NAME),
@@ -602,10 +602,18 @@ def test_serialize_deserialize_circuit_with_tokens():
602602
),
603603
constants=[
604604
v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='2_4')),
605-
v2.program_pb2.Constant(string_value='abc123'),
605+
v2.program_pb2.Constant(
606+
tag_value=v2.program_pb2.Tag(
607+
calibration_tag=v2.program_pb2.CalibrationTag(token='abc123')
608+
)
609+
),
606610
v2.program_pb2.Constant(operation_value=op_q0_tag1),
607611
v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='2_5')),
608-
v2.program_pb2.Constant(string_value='def456'),
612+
v2.program_pb2.Constant(
613+
tag_value=v2.program_pb2.Tag(
614+
calibration_tag=v2.program_pb2.CalibrationTag(token='def456')
615+
)
616+
),
609617
v2.program_pb2.Constant(operation_value=op_q1_tag2),
610618
v2.program_pb2.Constant(moment_value=v2.program_pb2.Moment(operation_indices=[2, 5])),
611619
v2.program_pb2.Constant(operation_value=op_q0_tag2),
@@ -628,12 +636,14 @@ def test_serialize_deserialize_circuit_tags():
628636
proto = v2.program_pb2.Program(
629637
language=v2.program_pb2.Language(arg_function_language='exp', gate_set=_SERIALIZER_NAME),
630638
circuit=v2.program_pb2.Circuit(
631-
scheduling_strategy=v2.program_pb2.Circuit.MOMENT_BY_MOMENT,
632-
tag_indices=[1],
633-
token_constant_index=0,
639+
scheduling_strategy=v2.program_pb2.Circuit.MOMENT_BY_MOMENT, tag_indices=[0, 1]
634640
),
635641
constants=[
636-
v2.program_pb2.Constant(string_value="abc123"),
642+
v2.program_pb2.Constant(
643+
tag_value=v2.program_pb2.Tag(
644+
calibration_tag=v2.program_pb2.CalibrationTag(token="abc123")
645+
)
646+
),
637647
v2.program_pb2.Constant(
638648
tag_value=v2.program_pb2.Tag(
639649
internal_tag=v2.program_pb2.InternalTag(
@@ -697,7 +707,7 @@ def test_serialize_deserialize_circuit_with_subcircuit():
697707
op_tag = v2.program_pb2.Operation()
698708
op_tag.xpowgate.exponent.float_value = 1.0
699709
op_tag.qubit_constant_index.append(0)
700-
op_tag.token_constant_index = 1
710+
op_tag.tag_indices.append(1)
701711
op_symbol = v2.program_pb2.Operation()
702712
op_symbol.xpowgate.exponent.func.type = 'mul'
703713
op_symbol.xpowgate.exponent.func.args.add().arg_value.float_value = 2.0
@@ -726,7 +736,11 @@ def test_serialize_deserialize_circuit_with_subcircuit():
726736
),
727737
constants=[
728738
v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='2_5')),
729-
v2.program_pb2.Constant(string_value='abc123'),
739+
v2.program_pb2.Constant(
740+
tag_value=v2.program_pb2.Tag(
741+
calibration_tag=v2.program_pb2.CalibrationTag(token='abc123')
742+
)
743+
),
730744
v2.program_pb2.Constant(operation_value=op_tag),
731745
v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='2_4')),
732746
v2.program_pb2.Constant(operation_value=op_symbol),

0 commit comments

Comments
 (0)