Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 65 additions & 27 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@
)
from cirq_google.ops.calibration_tag import CalibrationTag
from cirq_google.experimental.ops import CouplerPulse
from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs
from cirq_google.serialization import (
serializer,
op_deserializer,
op_serializer,
arg_func_langs,
tag_serializer,
tag_deserializer,
)

# The name used in program.proto to identify the serializer as CircuitSerializer.
# "v2.5" refers to the most current v2.Program proto format.
Expand Down Expand Up @@ -64,6 +71,8 @@ class CircuitSerializer(serializer.Serializer):
deserialization of this field is deployed.
op_serializer: Optional custom serializer for serializing unknown gates.
op_deserializer: Optional custom deserializer for deserializing unknown gates.
tag_serializer: Optional custom serializer for serializing unknown tags.
tag_deserializer: Optional custom deserializer for deserializing unknown tags.
"""

def __init__(
Expand All @@ -72,13 +81,17 @@ def __init__(
USE_CONSTANTS_TABLE_FOR_OPERATIONS=False,
op_serializer: Optional[op_serializer.OpSerializer] = None,
op_deserializer: Optional[op_deserializer.OpDeserializer] = None,
tag_serializer: Optional[tag_serializer.TagSerializer] = None,
tag_deserializer: Optional[tag_deserializer.TagDeserializer] = None,
):
"""Construct the circuit serializer object."""
super().__init__(gate_set_name=_SERIALIZER_NAME)
self.use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS
self.use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS
self.op_serializer = op_serializer
self.op_deserializer = op_deserializer
self.tag_serializer = tag_serializer
self.tag_deserializer = tag_deserializer

def serialize(
self, program: cirq.AbstractCircuit, msg: Optional[v2.program_pb2.Program] = None
Expand Down Expand Up @@ -301,8 +314,8 @@ def _serialize_gate_op(
msg.qubit_constant_index.append(raw_constants[qubit])

for tag in op.tags:
constant = v2.program_pb2.Constant()
if isinstance(tag, CalibrationTag):
constant = v2.program_pb2.Constant()
constant.string_value = tag.token
if tag.token in raw_constants:
msg.token_constant_index = raw_constants[tag.token]
Expand All @@ -317,16 +330,22 @@ def _serialize_gate_op(
# TODO(dstrain): Remove this once we are deserializing tag indices everywhere.
tag.to_proto(msg=msg.tags.add())
if (tag_index := raw_constants.get(tag, None)) is None:
constant = v2.program_pb2.Constant()
tag_index = len(constants)
if getattr(tag, 'to_proto', None) is not None:
if self.tag_serializer and self.tag_serializer.can_serialize_tag(tag):
self.tag_serializer.to_proto(
tag,
msg=constant.tag_value,
constants=constants,
raw_constants=raw_constants,
)
elif getattr(tag, 'to_proto', None) is not None:
tag.to_proto(constant.tag_value) # type: ignore
constants.append(constant)
if raw_constants is not None:
raw_constants[tag] = tag_index
msg.tag_indices.append(tag_index)
else:
warnings.warn(f'Unrecognized Tag {tag}, not serializing.')
if constant.WhichOneof('const_value'):
constants.append(constant)
if raw_constants is not None:
raw_constants[tag] = len(constants) - 1
msg.tag_indices.append(len(constants) - 1)
else:
msg.tag_indices.append(tag_index)
return msg
Expand Down Expand Up @@ -434,7 +453,18 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
)
)
elif which_const == 'tag_value':
deserialized_constants.append(self._deserialize_tag(constant.tag_value))
if self.tag_deserializer and self.tag_deserializer.can_deserialize_proto(
constant.tag_value
):
deserialized_constants.append(
self.tag_deserializer.from_proto(
constant.tag_value,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
)
else:
deserialized_constants.append(self._deserialize_tag(constant.tag_value))
else:
msg = f'Unrecognized constant type {which_const}, ignoring.' # pragma: no cover
warnings.warn(msg) # pragma: no cover
Expand Down Expand Up @@ -490,22 +520,7 @@ def _deserialize_moment(
gate_op = self._deserialize_gate_op(
op, constants=constants, deserialized_constants=deserialized_constants
)
if op.tag_indices:
tags = [
deserialized_constants[tag_index]
for tag_index in op.tag_indices
if deserialized_constants[tag_index] not in gate_op.tags
and deserialized_constants[tag_index] is not None
]
else:
tags = []
for tag in op.tags:
if (
tag not in gate_op.tags
and (new_tag := self._deserialize_tag(tag)) is not None
):
tags.append(new_tag)
moment_ops.append(gate_op.with_tags(*tags))
moment_ops.append(gate_op)
for op in moment_proto.circuit_operations:
moment_ops.append(
self._deserialize_circuit_op(
Expand Down Expand Up @@ -768,7 +783,30 @@ def _deserialize_gate_op(
elif which == 'token_value':
op = op.with_tags(CalibrationTag(operation_proto.token_value))

return op
# Add tags to op
if operation_proto.tag_indices and deserialized_constants is not None:
tags = [
deserialized_constants[tag_index]
for tag_index in operation_proto.tag_indices
if deserialized_constants[tag_index] not in op.tags
and deserialized_constants[tag_index] is not None
]
else:
tags = []
for tag in operation_proto.tags:
if tag not in op.tags:
if self.tag_deserializer and self.tag_deserializer.can_deserialize_proto(tag):
tags.append(
self.tag_deserializer.from_proto(
tag,
constants=constants or [],
deserialized_constants=deserialized_constants or [],
)
)
elif (new_tag := self._deserialize_tag(tag)) is not None:
tags.append(new_tag)

return op.with_tags(*tags)

def _deserialize_circuit_op(
self,
Expand Down
102 changes: 98 additions & 4 deletions cirq-google/cirq_google/serialization/circuit_serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional
import pytest

import attrs
import numpy as np
import sympy
from google.protobuf import json_format
Expand All @@ -27,6 +28,8 @@
from cirq_google.serialization.circuit_serializer import _SERIALIZER_NAME
from cirq_google.serialization.op_deserializer import OpDeserializer
from cirq_google.serialization.op_serializer import OpSerializer
from cirq_google.serialization.tag_deserializer import TagDeserializer
from cirq_google.serialization.tag_serializer import TagSerializer


class FakeDevice(cirq.Device):
Expand Down Expand Up @@ -916,10 +919,10 @@ def test_backwards_compatibility_with_old_tags():
),
constants=[v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='1_1'))],
)
expected_circuit_no_tag = cirq.Circuit(
expected_circuit = cirq.Circuit(
cirq.X(cirq.GridQubit(1, 1)).with_tags(cg.ops.DynamicalDecouplingTag(protocol='X'))
)
assert cg.CIRCUIT_SERIALIZER.deserialize(circuit_proto) == expected_circuit_no_tag
assert cg.CIRCUIT_SERIALIZER.deserialize(circuit_proto) == expected_circuit


def test_circuit_with_units():
Expand Down Expand Up @@ -949,7 +952,7 @@ def can_serialize_operation(self, op):

def to_proto(
self,
op: cirq.CircuitOperation,
op: cirq.Operation,
msg: Optional[v2.program_pb2.CircuitOperation] = None,
*,
constants: List[v2.program_pb2.Constant],
Expand Down Expand Up @@ -1008,7 +1011,7 @@ def test_serdes_preserves_syc():


@pytest.mark.parametrize('use_constants_table', [True, False])
def test_custom_serializer(use_constants_table: bool):
def test_custom_op_serializer(use_constants_table: bool):
c = cirq.Circuit(BingBongGate(param=2.5)(cirq.q(0, 0)))
serializer = cg.CircuitSerializer(
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
Expand All @@ -1026,6 +1029,97 @@ def test_custom_serializer(use_constants_table: bool):
assert op.qubits == (cirq.q(0, 0),)


@attrs.frozen
class DiscountTag:
discount: float


class DiscountTagSerializer(TagSerializer):
"""Describes how to serialize DiscountTag."""

def can_serialize_tag(self, tag):
return isinstance(tag, DiscountTag)

def to_proto(
self,
tag: Any,
msg: Optional[v2.program_pb2.Tag] = None,
*,
constants: List[v2.program_pb2.Constant],
raw_constants: Dict[Any, int],
) -> v2.program_pb2.Tag:
assert isinstance(tag, DiscountTag)
if msg is None:
msg = v2.program_pb2.Tag() # pragma: nocover
msg.internal_tag.tag_name = 'Discount'
msg.internal_tag.tag_package = 'test'
msg.internal_tag.tag_args['discount'].arg_value.float_value = tag.discount
return msg


class DiscountTagDeserializer(TagDeserializer):
"""Describes how to serialize CircuitOperations."""

def can_deserialize_proto(self, proto):
return (
proto.WhichOneof("tag") == "internal_tag"
and proto.internal_tag.tag_name == 'Discount'
and proto.internal_tag.tag_package == 'test'
)

def from_proto(
self,
proto: v2.program_pb2.Operation,
*,
constants: List[v2.program_pb2.Constant],
deserialized_constants: List[Any],
) -> DiscountTag:
return DiscountTag(discount=proto.internal_tag.tag_args["discount"].arg_value.float_value)


@pytest.mark.parametrize('use_constants_table', [True, False])
def test_custom_tag_serializer(use_constants_table: bool):
c = cirq.Circuit(cirq.X(cirq.q(0, 0)).with_tags(DiscountTag(0.25)))
serializer = cg.CircuitSerializer(
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
tag_serializer=DiscountTagSerializer(),
tag_deserializer=DiscountTagDeserializer(),
)
msg = serializer.serialize(c)
deserialized_circuit = serializer.deserialize(msg)
moment = deserialized_circuit[0]
assert len(moment) == 1
op = moment[cirq.q(0, 0)]
assert len(op.tags) == 1
assert isinstance(op.tags[0], DiscountTag)
assert op.tags[0].discount == 0.25


def test_custom_tag_serializer_with_tags_outside_constants():
op_tag = v2.program_pb2.Operation()
op_tag.xpowgate.exponent.float_value = 1.0
op_tag.qubit_constant_index.append(0)
tag = v2.program_pb2.Tag()
tag.internal_tag.tag_name = 'Discount'
tag.internal_tag.tag_package = 'test'
tag.internal_tag.tag_args['discount'].arg_value.float_value = 0.5
op_tag.tags.append(tag)
circuit_proto = v2.program_pb2.Program(
language=v2.program_pb2.Language(arg_function_language='exp', gate_set=_SERIALIZER_NAME),
circuit=v2.program_pb2.Circuit(
scheduling_strategy=v2.program_pb2.Circuit.MOMENT_BY_MOMENT,
moments=[v2.program_pb2.Moment(operations=[op_tag])],
),
constants=[v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='1_1'))],
)
expected_circuit = cirq.Circuit(cirq.X(cirq.GridQubit(1, 1)).with_tags(DiscountTag(0.50)))
serializer = cg.CircuitSerializer(
tag_serializer=DiscountTagSerializer(), tag_deserializer=DiscountTagDeserializer()
)
assert serializer.deserialize(circuit_proto) == expected_circuit


def test_reset_gate_with_improper_argument():
serializer = cg.CircuitSerializer()

Expand Down
51 changes: 51 additions & 0 deletions cirq-google/cirq_google/serialization/tag_deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2025 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List

import abc

from cirq_google.api import v2


class TagDeserializer(abc.ABC):
"""Generic supertype for tag deserializers.

Each tag deserializer describes how to deserialize a specific
set of tag protos.
"""

@abc.abstractmethod
def can_deserialize_proto(self, proto: v2.program_pb2.Tag) -> bool:
"""Whether the given tag can be serialized by this serializer."""

@abc.abstractmethod
def from_proto(
self,
proto: v2.program_pb2.Tag,
*,
constants: List[v2.program_pb2.Constant],
deserialized_constants: List[Any],
) -> Any:
"""Converts a proto-formatted operation into a Cirq operation.

Args:
proto: The proto object to be deserialized.
constants: The list of Constant protos referenced by constant
table indices in `proto`.
deserialized_constants: The deserialized contents of `constants`.

Returns:
The deserialized operation represented by `proto`.
"""
Loading