Skip to content

Commit e2de439

Browse files
authored
Remove [float,int] unions from type declarations (#7042)
* Clean up redundant complex type unions * format * fix numpy ufunc, one mypy err * remove float/int unions * Fix merge * Remove test * lint
1 parent 7f46121 commit e2de439

18 files changed

+59
-96
lines changed

cirq-core/cirq/circuits/circuit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def __eq__(self, other) -> bool:
216216
and all(m0 == m1 for m0, m1 in zip(self.moments, other.moments))
217217
)
218218

219-
def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
219+
def _approx_eq_(self, other: Any, atol: float) -> bool:
220220
"""See `cirq.protocols.SupportsApproximateEquality`."""
221221
if not isinstance(other, AbstractCircuit):
222222
return NotImplemented

cirq-core/cirq/circuits/moment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def __eq__(self, other) -> bool:
349349

350350
return self is other or self._sorted_operations_() == other._sorted_operations_()
351351

352-
def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
352+
def _approx_eq_(self, other: Any, atol: float) -> bool:
353353
"""See `cirq.protocols.SupportsApproximateEquality`."""
354354
if not isinstance(other, type(self)):
355355
return NotImplemented

cirq-core/cirq/circuits/text_diagram_drawer.py

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
Sequence,
2626
Tuple,
2727
TYPE_CHECKING,
28-
Union,
2928
)
3029

3130
import numpy as np
@@ -45,23 +44,11 @@
4544

4645
_HorizontalLine = NamedTuple(
4746
'_HorizontalLine',
48-
[
49-
('y', Union[int, float]),
50-
('x1', Union[int, float]),
51-
('x2', Union[int, float]),
52-
('emphasize', bool),
53-
('doubled', bool),
54-
],
47+
[('y', float), ('x1', float), ('x2', float), ('emphasize', bool), ('doubled', bool)],
5548
)
5649
_VerticalLine = NamedTuple(
5750
'_VerticalLine',
58-
[
59-
('x', Union[int, float]),
60-
('y1', Union[int, float]),
61-
('y2', Union[int, float]),
62-
('emphasize', bool),
63-
('doubled', bool),
64-
],
51+
[('x', float), ('y1', float), ('y2', float), ('emphasize', bool), ('doubled', bool)],
6552
)
6653
_DiagramText = NamedTuple('_DiagramText', [('text', str), ('transposed_text', str)])
6754

@@ -99,10 +86,10 @@ def __init__(
9986
self.vertical_lines: List[_VerticalLine] = (
10087
[] if vertical_lines is None else list(vertical_lines)
10188
)
102-
self.horizontal_padding: Dict[int, Union[int, float]] = (
89+
self.horizontal_padding: Dict[int, float] = (
10390
dict() if horizontal_padding is None else dict(horizontal_padding)
10491
)
105-
self.vertical_padding: Dict[int, Union[int, float]] = (
92+
self.vertical_padding: Dict[int, float] = (
10693
dict() if vertical_padding is None else dict(vertical_padding)
10794
)
10895

@@ -171,24 +158,14 @@ def grid_line(
171158
raise ValueError("Line is neither horizontal nor vertical")
172159

173160
def vertical_line(
174-
self,
175-
x: Union[int, float],
176-
y1: Union[int, float],
177-
y2: Union[int, float],
178-
emphasize: bool = False,
179-
doubled: bool = False,
161+
self, x: float, y1: float, y2: float, emphasize: bool = False, doubled: bool = False
180162
) -> None:
181163
"""Adds a line from (x, y1) to (x, y2)."""
182164
y1, y2 = sorted([y1, y2])
183165
self.vertical_lines.append(_VerticalLine(x, y1, y2, emphasize, doubled))
184166

185167
def horizontal_line(
186-
self,
187-
y: Union[int, float],
188-
x1: Union[int, float],
189-
x2: Union[int, float],
190-
emphasize: bool = False,
191-
doubled: bool = False,
168+
self, y: float, x1: float, x2: float, emphasize: bool = False, doubled: bool = False
192169
) -> None:
193170
"""Adds a line from (x1, y) to (x2, y)."""
194171
x1, x2 = sorted([x1, x2])
@@ -228,26 +205,21 @@ def height(self) -> int:
228205
max_y = max(max_y, v.y1, v.y2)
229206
return 1 + int(max_y)
230207

231-
def force_horizontal_padding_after(self, index: int, padding: Union[int, float]) -> None:
208+
def force_horizontal_padding_after(self, index: int, padding: float) -> None:
232209
"""Change the padding after the given column."""
233210
self.horizontal_padding[index] = padding
234211

235-
def force_vertical_padding_after(self, index: int, padding: Union[int, float]) -> None:
212+
def force_vertical_padding_after(self, index: int, padding: float) -> None:
236213
"""Change the padding after the given row."""
237214
self.vertical_padding[index] = padding
238215

239-
def _transform_coordinates(
240-
self,
241-
func: Callable[
242-
[Union[int, float], Union[int, float]], Tuple[Union[int, float], Union[int, float]]
243-
],
244-
) -> None:
216+
def _transform_coordinates(self, func: Callable[[float, float], Tuple[float, float]]) -> None:
245217
"""Helper method to transformer either row or column coordinates."""
246218

247-
def func_x(x: Union[int, float]) -> Union[int, float]:
219+
def func_x(x: float) -> float:
248220
return func(x, 0)[0]
249221

250-
def func_y(y: Union[int, float]) -> Union[int, float]:
222+
def func_y(y: float) -> float:
251223
return func(0, y)[1]
252224

253225
self.entries = {
@@ -271,19 +243,15 @@ def func_y(y: Union[int, float]) -> Union[int, float]:
271243
def insert_empty_columns(self, x: int, amount: int = 1) -> None:
272244
"""Insert a number of columns after the given column."""
273245

274-
def transform_columns(
275-
column: Union[int, float], row: Union[int, float]
276-
) -> Tuple[Union[int, float], Union[int, float]]:
246+
def transform_columns(column: float, row: float) -> Tuple[float, float]:
277247
return column + (amount if column >= x else 0), row
278248

279249
self._transform_coordinates(transform_columns)
280250

281251
def insert_empty_rows(self, y: int, amount: int = 1) -> None:
282252
"""Insert a number of rows after the given row."""
283253

284-
def transform_rows(
285-
column: Union[int, float], row: Union[int, float]
286-
) -> Tuple[Union[int, float], Union[int, float]]:
254+
def transform_rows(column: float, row: float) -> Tuple[float, float]:
287255
return column, row + (amount if row >= y else 0)
288256

289257
self._transform_coordinates(transform_rows)

cirq-core/cirq/linalg/transformations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ class EntangledStateError(ValueError):
422422

423423

424424
def partial_trace_of_state_vector_as_mixture(
425-
state_vector: np.ndarray, keep_indices: List[int], *, atol: Union[int, float] = 1e-8
425+
state_vector: np.ndarray, keep_indices: List[int], *, atol: float = 1e-8
426426
) -> Tuple[Tuple[float, np.ndarray], ...]:
427427
"""Returns a mixture representing a state vector with only some qubits kept.
428428
@@ -481,7 +481,7 @@ def sub_state_vector(
481481
keep_indices: List[int],
482482
*,
483483
default: np.ndarray = RaiseValueErrorIfNotProvided,
484-
atol: Union[int, float] = 1e-6,
484+
atol: float = 1e-6,
485485
) -> np.ndarray:
486486
r"""Attempts to factor a state vector into two parts and return one of them.
487487

cirq-core/cirq/ops/clifford_gate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def _to_phased_xz_gate(self) -> phased_x_z_gate.PhasedXZGate:
742742
z = -0.5 if x_to_flip else 0.5
743743
return phased_x_z_gate.PhasedXZGate(x_exponent=x, z_exponent=z, axis_phase_exponent=a)
744744

745-
def __pow__(self, exponent: Union[float, int]) -> 'SingleQubitCliffordGate':
745+
def __pow__(self, exponent: float) -> 'SingleQubitCliffordGate':
746746
if int(exponent) == exponent:
747747
# The single qubit Clifford gates are a group of size 24
748748
ret_gate = super().__pow__(int(exponent) % 24)

cirq-core/cirq/ops/dense_pauli_string.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool)
203203
def __pos__(self):
204204
return self
205205

206-
def __pow__(self, power: Union[int, float]) -> Union[NotImplementedType, Self]:
206+
def __pow__(self, power: float) -> Union[NotImplementedType, Self]:
207207
concrete_class = type(self)
208208
if isinstance(power, int):
209209
i_group = [1, +1j, -1, -1j]

cirq-core/cirq/ops/gate_operation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _qasm_(self, args: 'protocols.QasmArgs') -> Optional[str]:
360360
return protocols.qasm(self.gate, args=args, qubits=self.qubits, default=None)
361361

362362
def _equal_up_to_global_phase_(
363-
self, other: Any, atol: Union[int, float] = 1e-8
363+
self, other: Any, atol: float = 1e-8
364364
) -> Union[NotImplementedType, bool]:
365365
if not isinstance(other, type(self)):
366366
return NotImplemented

cirq-core/cirq/ops/pauli_string_phasor.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def __init__(
6666
pauli_string: ps.PauliString,
6767
qubits: Optional[Sequence['cirq.Qid']] = None,
6868
*,
69-
exponent_neg: Union[int, float, sympy.Expr] = 1,
70-
exponent_pos: Union[int, float, sympy.Expr] = 0,
69+
exponent_neg: 'cirq.TParamVal' = 1,
70+
exponent_pos: 'cirq.TParamVal' = 0,
7171
) -> None:
7272
"""Initializes the operation.
7373
@@ -112,12 +112,12 @@ def gate(self) -> 'cirq.PauliStringPhasorGate':
112112
return cast(PauliStringPhasorGate, self._gate)
113113

114114
@property
115-
def exponent_neg(self) -> Union[int, float, sympy.Expr]:
115+
def exponent_neg(self) -> 'cirq.TParamVal':
116116
"""The negative exponent."""
117117
return self.gate.exponent_neg
118118

119119
@property
120-
def exponent_pos(self) -> Union[int, float, sympy.Expr]:
120+
def exponent_pos(self) -> 'cirq.TParamVal':
121121
"""The positive exponent."""
122122
return self.gate.exponent_pos
123123

@@ -127,7 +127,7 @@ def pauli_string(self) -> 'cirq.PauliString':
127127
return self._pauli_string
128128

129129
@property
130-
def exponent_relative(self) -> Union[int, float, sympy.Expr]:
130+
def exponent_relative(self) -> 'cirq.TParamVal':
131131
"""The relative exponent between negative and positive exponents."""
132132
return self.gate.exponent_relative
133133

@@ -278,8 +278,8 @@ def __init__(
278278
self,
279279
dense_pauli_string: dps.DensePauliString,
280280
*,
281-
exponent_neg: Union[int, float, sympy.Expr] = 1,
282-
exponent_pos: Union[int, float, sympy.Expr] = 0,
281+
exponent_neg: 'cirq.TParamVal' = 1,
282+
exponent_pos: 'cirq.TParamVal' = 0,
283283
) -> None:
284284
"""Initializes the PauliStringPhasorGate.
285285
@@ -309,17 +309,17 @@ def __init__(
309309
self._exponent_pos = value.canonicalize_half_turns(exponent_pos)
310310

311311
@property
312-
def exponent_relative(self) -> Union[int, float, sympy.Expr]:
312+
def exponent_relative(self) -> 'cirq.TParamVal':
313313
"""The relative exponent between negative and positive exponents."""
314314
return value.canonicalize_half_turns(self.exponent_neg - self.exponent_pos)
315315

316316
@property
317-
def exponent_neg(self) -> Union[int, float, sympy.Expr]:
317+
def exponent_neg(self) -> 'cirq.TParamVal':
318318
"""The negative exponent."""
319319
return self._exponent_neg
320320

321321
@property
322-
def exponent_pos(self) -> Union[int, float, sympy.Expr]:
322+
def exponent_pos(self) -> 'cirq.TParamVal':
323323
"""The positive exponent."""
324324
return self._exponent_pos
325325

cirq-core/cirq/ops/pauli_sum_exponential.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Iterator, Tuple, Union, TYPE_CHECKING
14+
from typing import Any, Iterator, Tuple, TYPE_CHECKING
1515

1616
import numpy as np
17-
import sympy
1817

1918
from cirq import linalg, protocols, value, _compat
2019
from cirq.ops import linear_combinations, pauli_string_phasor
@@ -45,7 +44,7 @@ class returns an operation which is equivalent to
4544
def __init__(
4645
self,
4746
pauli_sum_like: 'cirq.PauliSumLike',
48-
exponent: Union[int, float, sympy.Expr] = 1,
47+
exponent: 'cirq.TParamVal' = 1,
4948
atol: float = 1e-8,
5049
):
5150
pauli_sum = linear_combinations.PauliSum.wrap(pauli_sum_like)

cirq-core/cirq/ops/phased_x_z_gate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> Iterator['cirq.OP_TREE']:
195195
yield ops.X(q) ** self._x_exponent
196196
yield ops.Z(q) ** (self._axis_phase_exponent + self._z_exponent)
197197

198-
def __pow__(self, exponent: Union[float, int]) -> 'PhasedXZGate':
198+
def __pow__(self, exponent: float) -> 'PhasedXZGate':
199199
if exponent == 1:
200200
return self
201201
if exponent == -1:

0 commit comments

Comments
 (0)