Skip to content

Commit f7f54e2

Browse files
authored
Fix test flake in pauli_string_measurement_with_readout_mitigation_test (#7459)
Avoid possible KeyError in `test_process_pauli_measurement_results_raises_error_on_missing_calibration` Problem: Random Pauli strings generated in the test may not contain all qubits as is assumed in the `empty_calibration_result_dict` key. Solution: Make dictionary with the qubits present in Pauli strings. Fixes spurious test failure ``` pytest --randomly-seed=532866775 \ cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py \ -k test_process_pauli_measurement_results_raises_error_on_missing_calibration ``` Also fix - unnecessarily repeated list construction - unintentional dropping of circuits in a loop - Nit - sync docstring with function arguments
1 parent e34da07 commit f7f54e2

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def _build_many_one_qubits_empty_confusion_matrix(qubits_length: int) -> list[np
286286

287287

288288
def _process_pauli_measurement_results(
289-
qubits: list[ops.Qid],
289+
qubits: Sequence[ops.Qid],
290290
pauli_string_groups: list[list[ops.PauliString]],
291291
circuit_results: list[ResultDict],
292292
calibration_results: dict[tuple[ops.Qid, ...], SingleQubitReadoutCalibrationResult],
@@ -304,10 +304,11 @@ def _process_pauli_measurement_results(
304304
305305
Args:
306306
qubits: Qubits to build confusion matrices for. In a sorted order.
307-
pauli_strings: The lists of QWC Pauli string groups that are measured.
307+
pauli_string_groups: The lists of QWC Pauli string groups that are measured.
308308
circuit_results: A list of ResultDict obtained
309309
from running the Pauli measurement circuits.
310-
confusion_matrices: A list of confusion matrices from calibration results.
310+
calibration_results: A dictionary of SingleQubitReadoutCalibrationResult
311+
for tuples of qubits present in `pauli_string_groups`.
311312
pauli_repetitions: The number of repetitions used for Pauli string measurements.
312313
timestamp: The timestamp of the calibration results.
313314
disable_readout_mitigation: If set to True, returns no error-mitigated error
@@ -326,7 +327,7 @@ def _process_pauli_measurement_results(
326327

327328
calibration_result = (
328329
calibration_results[tuple(pauli_readout_qubits)]
329-
if disable_readout_mitigation is False
330+
if not disable_readout_mitigation
330331
else None
331332
)
332333

@@ -458,9 +459,9 @@ def measure_pauli_strings(
458459
qubits_list = sorted(unique_qubit_tuples)
459460

460461
# Build the basis-change circuits for each Pauli string group
461-
pauli_measurement_circuits = list[circuits.Circuit]()
462+
pauli_measurement_circuits: list[circuits.Circuit] = []
462463
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
463-
qid_list = list(sorted(input_circuit.all_qubits()))
464+
qid_list = sorted(input_circuit.all_qubits())
464465
basis_change_circuits = []
465466
input_circuit_unfrozen = input_circuit.unfreeze()
466467
for pauli_strings in pauli_string_groups:

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -874,23 +874,20 @@ def test_group_paulis_type_mismatch() -> None:
874874

875875
def test_process_pauli_measurement_results_raises_error_on_missing_calibration() -> None:
876876
"""Test that the function raises an error if the calibration result is missing."""
877-
qubits: list[cirq.Qid] = [q for q in cirq.LineQubit.range(5)]
877+
qubits: Sequence[cirq.Qid] = cirq.LineQubit.range(5)
878878

879879
measurement_op = cirq.measure(*qubits, key='m')
880-
test_circuits = list[cirq.Circuit]()
881-
for _ in range(3):
882-
circuit_list = []
883-
884-
circuit = _create_ghz(5, qubits) + measurement_op
885-
circuit_list.append(circuit)
886-
test_circuits.extend(circuit_list)
880+
test_circuits: list[cirq.Circuit] = [_create_ghz(5, qubits) + measurement_op for _ in range(3)]
887881

888882
pauli_strings = [_generate_random_pauli_string(qubits, True) for _ in range(3)]
889883
sampler = cirq.Simulator()
890884

891885
circuit_results = sampler.run_batch(test_circuits, repetitions=1000)
892886

893-
empty_calibration_result_dict = {tuple(qubits): None}
887+
pauli_strings_qubits = sorted(
888+
set(itertools.chain.from_iterable(ps.qubits for ps in pauli_strings))
889+
)
890+
empty_calibration_result_dict = {tuple(pauli_strings_qubits): None}
894891

895892
with pytest.raises(
896893
ValueError,

0 commit comments

Comments
 (0)