Source code for mbrs.timer

  1"""
  2Timer module including some useful global objects.
  3
  4Example:
  5    >>> from mbrs import timer
  6    >>> for hyps, src in zip(hypotheses, sources):
  7    ...     with timer.measure("encode/hypotheses") as t:
  8    ...         h = metric.encode(hyps)
  9    ...         t.set_delta_ncalls(len(hyps))
 10    ...     with timer.measure("encode/source"):
 11    ...         s = metric.encode([src])
 12    ...     with timer.measure("score"):
 13    ...         scores = metric.score(h, s)
 14    >>> res = timer.aggregate().result()  # return the result table
 15"""
 16
 17from __future__ import annotations
 18
 19import contextlib
 20import time
 21from collections import defaultdict
 22from dataclasses import dataclass, field
 23
 24
class Stopwatch:
    """Stopwatch class to measure the elapsed time.

    Each use of the stopwatch as a context manager adds the wall-clock time
    spent inside the ``with`` body to an accumulator and increases the call
    counter by the current "delta" (1 unless changed with
    :meth:`set_delta_ncalls` inside the ``with`` body).

    Example:
        >>> timer = Stopwatch()
        >>> for i in range(10):
        ...     with timer():
        ...         time.sleep(1)
        >>> print(f"{timer.elapsed_time:.3f}")
        10.000

        >>> timer = Stopwatch()
        >>> for i in range(10):
        ...     with timer() as t:
        ...         time.sleep(2)
        ...         t.set_delta_ncalls(2)
        >>> print(f"{timer.elapsed_time:.3f}")
        20.000
        >>> print(f"{timer.ncalls}")
        20
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Reset the stopwatch."""
        self._acc_time = 0.0
        self._acc_ncalls = 0
        self._delta_ncalls = 1

    @contextlib.contextmanager
    def __call__(self):
        """Measure the time spent inside the ``with`` body.

        Yields:
            Stopwatch: This stopwatch, so that the body can call
            :meth:`set_delta_ncalls`.
        """
        # Snapshot the accumulators so that, on exit, we can tell whether a
        # nested use of this same stopwatch has already updated them.
        time_at_entry = self._acc_time
        ncalls_at_entry = self._acc_ncalls
        start = time.perf_counter()
        try:
            self.set_delta_ncalls(1)  # Set to 1 in the default
            yield self
        finally:
            # Treat nest calling: only update an accumulator that is still
            # equal to its value at entry, i.e. one that no inner call of
            # this stopwatch has modified in the meantime.
            if time_at_entry == self._acc_time:
                self._acc_time = time_at_entry + time.perf_counter() - start
            if ncalls_at_entry == self._acc_ncalls:
                self._acc_ncalls = ncalls_at_entry + self._delta_ncalls

    def set_delta_ncalls(self, delta: int = 1) -> None:
        """Set delta for counting the number of calls.

        The value is read when the ``with`` body exits and is reset to 1 on
        every entry, so it must be set inside the ``with`` body.

        Args:
            delta (int): Amount added to the call counter on exit.
        """
        self._delta_ncalls = delta

    @property
    def elapsed_time(self) -> float:
        """Return the total elapsed time."""
        return self._acc_time

    @property
    def elpased_time(self) -> float:
        """Return the total elapsed time.

        Misspelled alias of :attr:`elapsed_time`, kept for backward
        compatibility with existing callers.
        """
        return self._acc_time

    @property
    def ncalls(self) -> int:
        """Return the number of calls."""
        return self._acc_ncalls
85 86
class StopwatchDict(defaultdict[str, Stopwatch]):
    """A dictionary of the :class:`Stopwatch` class.

    A missing key is filled with a fresh :class:`Stopwatch` on first access,
    so a timer never has to be registered before it is used.

    Example:
        >>> timers = StopwatchDict()
        >>> for i in range(10):
        ...     with timers("A"):
        ...         time.sleep(1)
        >>> for i in range(3):
        ...     with timers("B"):
        ...         time.sleep(1)
        >>> print({k: f"{v:.3f}" for k, v in timers.elapsed_time.items()})
        {'A': '10.000', 'B': '3.000'}
    """

    def __init__(self) -> None:
        super().__init__(Stopwatch)

    def reset(self) -> None:
        """Reset all stopwatches."""
        for stopwatch in self.values():
            stopwatch.reset()

    @contextlib.contextmanager
    def __call__(self, name: str):
        """Measure the time with the stopwatch registered under ``name``.

        Args:
            name (str): Key of the stopwatch; it is created if missing.

        Yields:
            Stopwatch: The stopwatch that is measuring the ``with`` body.
        """
        with self[name]() as timer:
            yield timer

    @property
    def elapsed_time(self) -> dict[str, float]:
        """Return the total elapsed time."""
        # Stopwatch exposes this accessor under the name `elpased_time`.
        return {name: stopwatch.elpased_time for name, stopwatch in self.items()}

    @property
    def ncalls(self) -> dict[str, int]:
        """Return the number of calls."""
        return {name: stopwatch.ncalls for name, stopwatch in self.items()}
# Module-level collection of named stopwatches; `aggregate()` builds its
# profile tree from this object.
measure = StopwatchDict()
@dataclass
class ProfileTree:
    """A node of the tree that groups timers by the parts of their names.

    ``elapsed_time`` and ``ncalls`` hold the totals of this node. A negative
    value means "not set yet"; :meth:`aggregate` fills such values of inner
    nodes with the sum over their children. ``children`` maps one name part
    to the sub-tree below it.
    """

    elapsed_time: float = -1.0
    ncalls: int = -1
    children: dict[str, ProfileTree] = field(default_factory=dict)

    @property
    def is_leaf(self) -> bool:
        """Return whether the node is leaf or not."""
        return len(self.children) == 0

    def aggregate(self) -> None:
        """Fill the unset totals of inner nodes from their children.

        Values that are already set (non-negative) on an inner node are kept.

        Raises:
            RuntimeError: If a leaf has no ``elapsed_time`` or ``ncalls``.
        """
        if self.is_leaf:
            if self.elapsed_time < 0.0:
                raise RuntimeError("Missing elapsed_time")
            if self.ncalls < 0:
                raise RuntimeError("Missing ncalls")
            return

        # Children first, so their totals are final before they are summed.
        for child in self.children.values():
            child.aggregate()
        if self.elapsed_time < 0.0:
            self.elapsed_time = sum(
                (child.elapsed_time for child in self.children.values()), 0.0
            )
        if self.ncalls < 0:
            self.ncalls = sum(child.ncalls for child in self.children.values())

    @classmethod
    def build(cls, timers: StopwatchDict, separetor: str = "/") -> ProfileTree:
        """Build an aggregated profile tree from named timers.

        Args:
            timers (StopwatchDict): Mapping from timer name to its stopwatch.
            separetor (str): Separator that splits a timer name into the path
              of nodes from the root.

        Returns:
            ProfileTree: The aggregated root node.
        """
        root = cls()
        for name, timer in timers.items():
            node = root
            for part in name.split(separetor):
                if part not in node.children:
                    # Use `cls` so that a subclass builds a homogeneous tree.
                    node.children[part] = cls()
                node = node.children[part]
            # The stopwatch exposes its total under the name `elpased_time`.
            node.elapsed_time = timer.elpased_time
            node.ncalls = timer.ncalls
        if root.is_leaf:
            # No timers at all: give the root explicit zero totals instead of
            # letting `aggregate` reject it as a leaf with missing values.
            root.elapsed_time = 0.0
            root.ncalls = 0
        root.aggregate()
        return root

    def result(self, nsentences: int = -1) -> list[dict[str, str | int | float]]:
        """Return the statistics of all nodes below this one as table rows.

        Rows are listed depth-first, each parent before its children; the
        node this method is called on is not included.

        Args:
            nsentences (int): When positive, per-sentence columns are added.

        Returns:
            list[dict[str, str | int | float]]: One row per node.
        """

        def _result(name: str, node: ProfileTree) -> list[dict[str, str | int | float]]:
            stat: dict[str, str | int | float] = {
                "name": name.strip("/"),
                "acctime": node.elapsed_time,
                "acccalls": node.ncalls,
                # A node that was never called has no per-call time; report
                # 0.0 instead of dividing by zero.
                "ms/call": (
                    node.elapsed_time * 1000 / node.ncalls if node.ncalls else 0.0
                ),
            }
            if nsentences > 0:
                stat["ms/sentence"] = node.elapsed_time * 1000 / nsentences
                stat["calls/sentence"] = node.ncalls / nsentences
            rows = [stat]
            for path, child in node.children.items():
                rows += _result(name + "/" + path, child)
            return rows

        return _result("", self)[1:]  # Remove the root node
195 196
def aggregate() -> ProfileTree:
    """Build the profile tree of the module-level timers.

    Returns:
        ProfileTree: The root of the profile tree.
    """
    profile_root = ProfileTree.build(measure)
    return profile_root