diff --git a/cirq-core/cirq/circuits/insert_strategy.py b/cirq-core/cirq/circuits/insert_strategy.py index bcd9ee64a0b..589c8ada2cc 100644 --- a/cirq-core/cirq/circuits/insert_strategy.py +++ b/cirq-core/cirq/circuits/insert_strategy.py @@ -23,6 +23,16 @@ class InsertStrategy: INLINE: 'InsertStrategy' EARLIEST: 'InsertStrategy' + def __new__(cls, name: str, doc: str) -> 'InsertStrategy': + inst = getattr(cls, name, None) + if not inst or not isinstance(inst, cls): + inst = super().__new__(cls) + return inst + + def __getnewargs__(self): + """Returns a tuple of args to pass to __new__ when unpickling.""" + return (self.name, self.__doc__) + def __init__(self, name: str, doc: str): self.name = name self.__doc__ = doc diff --git a/cirq-core/cirq/circuits/insert_strategy_test.py b/cirq-core/cirq/circuits/insert_strategy_test.py index 7387ff67b59..b1e69661e37 100644 --- a/cirq-core/cirq/circuits/insert_strategy_test.py +++ b/cirq-core/cirq/circuits/insert_strategy_test.py @@ -12,8 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle + +import pytest + import cirq def test_repr(): assert repr(cirq.InsertStrategy.NEW) == 'cirq.InsertStrategy.NEW' + + +@pytest.mark.parametrize( + 'strategy', + [ + cirq.InsertStrategy.NEW, + cirq.InsertStrategy.NEW_THEN_INLINE, + cirq.InsertStrategy.INLINE, + cirq.InsertStrategy.EARLIEST, + ], + ids=lambda strategy: strategy.name, +) +def test_identity_after_pickling(strategy: cirq.InsertStrategy): + unpickled_strategy = pickle.loads(pickle.dumps(strategy)) + assert unpickled_strategy is strategy