Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- #277: Methods for pretty-printing `Pattern`: `to_ascii`,
`to_unicode`, `to_latex`.

- #300: Branch selection in simulation: in addition to
`RandomBranchSelector` which corresponds to the strategy that was
already implemented, the user can use `FixedBranchSelector`,
`ConstBranchSelector`, or define a custom branch selection by
deriving the abstract class `BranchSelector`.

- #312: The separation between `TensorNetworkBackend` and backends
that operate on a full-state representation, such as
`StatevecBackend` and `DensityMatrixBackend`, is now clearer with
Expand Down Expand Up @@ -61,6 +67,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- #277: The method `Pattern.print_pattern` is now deprecated.

- #300: `pr_calc` parameter is removed in back-end initializers.
The user can specify `pr_calc` in the constructor of
`RandomBranchSelector` instead.

- #300: `rng` is no longer stored in the backends; it is now passed as
an optional argument to each simulation method.

- #261: Moved all device interface functionalities to an external
library and removed their implementation from this library.

Expand Down
30 changes: 29 additions & 1 deletion docs/source/simulator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ Pattern Simulation

.. automethod:: run


Simulator backends
++++++++++++++++++

Expand Down Expand Up @@ -49,3 +48,32 @@ Density Matrix

.. autoclass:: DensityMatrix
:members:

Branch Selection: :mod:`graphix.branch_selector` module
+++++++++++++++++++++++++++++++++++++++++++++++++++++++

.. currentmodule:: graphix.branch_selector

Abstract Branch Selector
------------------------

.. autoclass:: BranchSelector
:members:

Random Branch Selector
----------------------

.. autoclass:: RandomBranchSelector
:members:

Fixed Branch Selector
---------------------

.. autoclass:: FixedBranchSelector
:members:

Constant Branch Selector
------------------------

.. autoclass:: ConstBranchSelector
:members:
149 changes: 149 additions & 0 deletions graphix/branch_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Branch selector.

Branch selectors determine the computation branch that is explored
during a simulation, meaning the choice of measurement outcomes. The
branch selection can be random (see :class:`RandomBranchSelector`) or
deterministic (see :class:`ConstBranchSelector`).

"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar

from typing_extensions import override

from graphix.measurements import Outcome, outcome
from graphix.rng import ensure_rng

if TYPE_CHECKING:
from collections.abc import Callable

from numpy.random import Generator


class BranchSelector(ABC):
"""Abstract class for branch selectors.

A branch selector provides the method `measure`, which returns the
measurement outcome (0 or 1) for a given qubit.
"""

@abstractmethod
def measure(self, qubit: int, f_expectation0: Callable[[], float], rng: Generator | None = None) -> Outcome:
"""Return the measurement outcome of ``qubit``.

Parameters
----------
qubit : int
Index of qubit to measure

f_expectation0 : Callable[[], float]
A function that the method can use to retrieve the expected
probability of outcome 0. The probability is computed only if
this function is called (lazy computation), ensuring no
unnecessary computational cost.

rng: Generator, optional
Random-number generator for measurements.
This generator is used only in case of random branch selection
(see :class:`RandomBranchSelector`).
If ``None``, a default random-number generator is used.
Default is ``None``.
"""


@dataclass
class RandomBranchSelector(BranchSelector):
"""Random branch selector.

Parameters
----------
pr_calc : bool, optional
Whether to compute the probability distribution before selecting the measurement result.
If ``False``, measurements yield 0/1 with equal probability (50% each).
Default is ``True``.
"""

pr_calc: bool = True

@override
def measure(self, qubit: int, f_expectation0: Callable[[], float], rng: Generator | None = None) -> Outcome:
"""
Return the measurement outcome of ``qubit``.

If ``pr_calc`` is ``True``, the measurement outcome is determined based on the
computed probability of outcome 0. Otherwise, the result is randomly chosen
with a 50% chance for either outcome.
"""
rng = ensure_rng(rng)
if self.pr_calc:
prob_0 = f_expectation0()
return outcome(rng.random() > prob_0)
result: Outcome = rng.choice([0, 1])
return result


_T = TypeVar("_T", bound=Mapping[int, Outcome])


@dataclass
class FixedBranchSelector(BranchSelector, Generic[_T]):
"""Branch selector with predefined measurement outcomes.

The mapping is fixed in ``results``. By default, an error is raised if
a qubit is measured without a predefined outcome. However, another
branch selector can be specified in ``default`` to handle such cases.

Parameters
----------
results : Mapping[int, bool]
A dictionary mapping qubits to their measurement outcomes.
If a qubit is not present in this mapping, the ``default`` branch
selector is used.
default : BranchSelector | None, optional
Branch selector to use for qubits not present in ``results``.
If ``None``, an error is raised when an unmapped qubit is measured.
Default is ``None``.
"""

results: _T
default: BranchSelector | None = None

@override
def measure(self, qubit: int, f_expectation0: Callable[[], float], rng: Generator | None = None) -> Outcome:
"""
Return the predefined measurement outcome of ``qubit``, if available.

If the qubit is not present in ``results``, the ``default`` branch selector
is used. If no default is provided, an error is raised.
"""
result = self.results.get(qubit)
if result is None:
if self.default is None:
raise ValueError(f"Unexpected measurement of qubit {qubit}.")
return self.default.measure(qubit, f_expectation0)
return result


@dataclass
class ConstBranchSelector(BranchSelector):
"""Branch selector with a constant measurement outcome.

The value ``result`` is returned for every qubit.

Parameters
----------
result : Outcome
The fixed measurement outcome for all qubits.
"""

result: Outcome

@override
def measure(self, qubit: int, f_expectation0: Callable[[], float], rng: Generator | None = None) -> Outcome:
"""Return the constant measurement outcome ``result`` for any qubit."""
return self.result
5 changes: 5 additions & 0 deletions graphix/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def outcome(b: bool) -> Outcome:
return 1 if b else 0


def toggle_outcome(outcome: Outcome) -> Outcome:
"""Toggle outcome."""
return 1 if outcome == 0 else 0


@dataclasses.dataclass
class Domains:
"""Represent `X^sZ^t` where s and t are XOR of results from given sets of indices."""
Expand Down
27 changes: 15 additions & 12 deletions graphix/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from graphix.fundamentals import Axis, Plane, Sign
from graphix.gflow import find_flow, find_gflow, get_layers
from graphix.graphsim import GraphState
from graphix.measurements import Outcome, PauliMeasurement
from graphix.measurements import Outcome, PauliMeasurement, toggle_outcome
from graphix.pretty_print import OutputFormat, pattern_to_str
from graphix.simulator import PatternSimulator
from graphix.states import BasicStates
Expand All @@ -32,7 +32,9 @@
if TYPE_CHECKING:
from collections.abc import Container, Iterator, Mapping
from collections.abc import Set as AbstractSet
from typing import Any, Literal
from typing import Any

from numpy.random import Generator

from graphix.parameter import ExpressionOrFloat, ExpressionOrSupportsFloat, Parameter
from graphix.sim import Backend, BackendState, Data
Expand Down Expand Up @@ -1355,7 +1357,11 @@ def space_list(self) -> list[int]:
return n_list

def simulate_pattern(
self, backend: Backend[_StateT_co] | str = "statevector", input_state: Data = BasicStates.PLUS, **kwargs: Any
self,
backend: Backend[_StateT_co] | str = "statevector",
input_state: Data = BasicStates.PLUS,
rng: Generator | None = None,
**kwargs: Any,
) -> BackendState:
"""Simulate the execution of the pattern by using :class:`graphix.simulator.PatternSimulator`.

Expand All @@ -1365,6 +1371,10 @@ def simulate_pattern(
----------
backend : str
optional parameter to select simulator backend.
rng: Generator, optional
Random-number generator for measurements.
This generator is used only in case of random branch selection
(see :class:`RandomBranchSelector`).
kwargs: keyword args for specified backend.

Returns
Expand All @@ -1375,7 +1385,7 @@ def simulate_pattern(
.. seealso:: :class:`graphix.simulator.PatternSimulator`
"""
sim = PatternSimulator(self, backend=backend, **kwargs)
sim.run(input_state)
sim.run(input_state, rng=rng)
return sim.backend.state

def perform_pauli_measurements(self, leave_input: bool = False, ignore_pauli_with_deps: bool = False) -> None:
Expand Down Expand Up @@ -1887,14 +1897,7 @@ def extract_signal(plane: Plane, s_domain: set[int], t_domain: set[int]) -> Extr
assert_never(plane)


def toggle_outcome(outcome: Literal[0, 1]) -> Literal[0, 1]:
"""Toggle outcome."""
if outcome == 0:
return 1
return 0


def shift_outcomes(outcomes: dict[int, Literal[0, 1]], signal_dict: dict[int, set[int]]) -> dict[int, Literal[0, 1]]:
def shift_outcomes(outcomes: dict[int, Outcome], signal_dict: dict[int, set[int]]) -> dict[int, Outcome]:
"""Update outcomes with shifted signals.

Shifted signals (as returned by the method
Expand Down
Loading
Loading