Source code for mbrs.timer
"""
Timer module including some useful global objects (e.g. ``measure``).

Example:
    >>> from mbrs import timer
    >>> for hyps, src in zip(hypotheses, sources):
    ...     with timer.measure("encode/hypotheses") as t:
    ...         h = metric.encode(hyps)
    ...         t.set_delta_ncalls(len(hyps))
    ...     with timer.measure("encode/source"):
    ...         s = metric.encode([src])
    ...     with timer.measure("score"):
    ...         scores = metric.score(h, s)
    >>> res = timer.aggregate().result()  # return the result table (a list of rows)
"""
16
17from __future__ import annotations
18
19import contextlib
20import time
21from collections import defaultdict
22from dataclasses import dataclass, field
23
24
[docs]
class Stopwatch:
    """Stopwatch that accumulates elapsed time and call counts over measurements.

    Example:
        >>> timer = Stopwatch()
        >>> for i in range(10):
        ...     with timer():
        ...         time.sleep(1)
        >>> print(f"{timer.elapsed_time:.3f}")  # approximately
        10.000

        >>> timer = Stopwatch()
        >>> for i in range(10):
        ...     with timer() as t:
        ...         time.sleep(2)
        ...         t.set_delta_ncalls(2)
        >>> print(f"{timer.elapsed_time:.3f}")  # approximately
        20.000
        >>> print(f"{timer.ncalls}")
        20
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Reset the accumulated time, the call count, and the per-call delta."""
        self._acc_time = 0.0
        self._acc_ncalls = 0
        self._delta_ncalls = 1

    @contextlib.contextmanager
    def __call__(self):
        """Measure the wall-clock time of the ``with`` body.

        Yields:
            Stopwatch: This stopwatch, so the body can call
            :meth:`set_delta_ncalls` to count the measurement as several calls.
        """
        time_before = self._acc_time
        ncalls_before = self._acc_ncalls
        start = time.perf_counter()
        try:
            # Each measurement counts as one call unless the body overrides it.
            self.set_delta_ncalls(1)
            yield self
        finally:
            # Nested use of the same stopwatch: if an inner measurement already
            # updated an accumulator, this outer measurement leaves it alone so
            # the work is not counted twice.
            # NOTE(review): the outer call's time spent outside the inner call
            # is therefore not recorded; confirm this is intended.
            if time_before == self._acc_time:
                self._acc_time = time_before + time.perf_counter() - start
            if ncalls_before == self._acc_ncalls:
                self._acc_ncalls = ncalls_before + self._delta_ncalls

    def set_delta_ncalls(self, delta: int = 1) -> None:
        """Set how many calls the current measurement counts as.

        Args:
            delta: Number added to :attr:`ncalls` when the measurement ends.
        """
        self._delta_ncalls = delta

    @property
    def elapsed_time(self) -> float:
        """Return the total elapsed time in seconds."""
        return self._acc_time

    @property
    def elpased_time(self) -> float:
        """Misspelled alias of :attr:`elapsed_time`, kept for backward compatibility."""
        return self.elapsed_time

    @property
    def ncalls(self) -> int:
        """Return the number of calls."""
        return self._acc_ncalls
85
86
[docs]
class StopwatchDict(defaultdict[str, Stopwatch]):
    """A dictionary of the :class:`Stopwatch` class.

    Looking up a missing name creates a fresh :class:`Stopwatch` for it.

    Example:
        >>> timers = StopwatchDict()
        >>> for i in range(10):
        ...     with timers("A"):
        ...         time.sleep(1)
        >>> for i in range(3):
        ...     with timers("B"):
        ...         time.sleep(1)
        >>> print({k: round(v, 3) for k, v in timers.elapsed_time.items()})  # approximately
        {'A': 10.0, 'B': 3.0}
    """

    def __init__(self, default_factory=Stopwatch, /, *args, **kwargs) -> None:
        # ``defaultdict.copy()``, ``copy.copy`` and pickling re-create the
        # object by calling ``type(self)(self.default_factory, ...)``, so the
        # factory and initial items must be accepted here. The default keeps
        # ``StopwatchDict()`` working exactly as before.
        super().__init__(default_factory, *args, **kwargs)

    def reset(self) -> None:
        """Reset all existing stopwatches (their names are kept)."""
        for t in self.values():
            t.reset()

    @contextlib.contextmanager
    def __call__(self, name: str):
        """Measure the time of the ``with`` body on the stopwatch ``name``.

        Yields:
            Stopwatch: The stopwatch registered under ``name``.
        """
        with self[name]() as timer:
            yield timer

    @property
    def elapsed_time(self) -> dict[str, float]:
        """Return the total elapsed time of each stopwatch."""
        # ``elpased_time`` (sic) is the property name Stopwatch has always exposed.
        return {k: v.elpased_time for k, v in self.items()}

    @property
    def ncalls(self) -> dict[str, int]:
        """Return the number of calls of each stopwatch."""
        return {k: v.ncalls for k, v in self.items()}
128
129
# Module-level registry of named stopwatches, used as
# ``with measure("name"): ...`` and summarized by ``aggregate()``.
measure: StopwatchDict = StopwatchDict()
131
132
[docs]
@dataclass
class ProfileTree:
    """Hierarchical summary of named timers.

    Timer names are split by a separator into paths, and each path component
    becomes a node. Nodes that correspond to a timer hold its values. The other
    inner nodes receive the sum of their children's values in :meth:`aggregate`.

    Attributes:
        elapsed_time: Accumulated time in seconds, or ``-1.0`` if not set yet.
        ncalls: Accumulated number of calls, or ``-1`` if not set yet.
        children: Child nodes keyed by path component.
    """

    elapsed_time: float = -1.0
    ncalls: int = -1
    children: dict[str, ProfileTree] = field(default_factory=dict)

    @property
    def is_leaf(self) -> bool:
        """Return whether the node has no children."""
        return not self.children

    def aggregate(self) -> None:
        """Recursively fill in missing values of inner nodes from their children.

        Values that are already set (non-negative) are kept unchanged.

        Raises:
            RuntimeError: If a leaf node is missing ``elapsed_time`` or ``ncalls``.
        """
        if self.is_leaf:
            if self.elapsed_time < 0.0:
                raise RuntimeError("Missing elapsed_time")
            if self.ncalls < 0:
                raise RuntimeError("Missing ncalls")
            return

        for child in self.children.values():
            child.aggregate()
        if self.elapsed_time < 0.0:
            self.elapsed_time = sum(
                (child.elapsed_time for child in self.children.values()), 0.0
            )
        if self.ncalls < 0:
            self.ncalls = sum(child.ncalls for child in self.children.values())

    @classmethod
    def build(
        cls,
        timers: StopwatchDict,
        separetor: str = "/",
        *,
        separator: str | None = None,
    ) -> ProfileTree:
        """Build an aggregated profile tree from named timers.

        Args:
            timers: Mapping from timer name to an object exposing
                ``elpased_time`` and ``ncalls`` (e.g. a :class:`StopwatchDict`).
            separetor: Separator of hierarchical timer names. The misspelled
                name is kept for backward compatibility.
            separator: Correctly spelled alias of ``separetor``; it takes
                precedence when given.

        Returns:
            ProfileTree: The aggregated root node. If ``timers`` is empty, the
            root has ``elapsed_time == 0.0`` and ``ncalls == 0``.

        Raises:
            RuntimeError: Propagated from :meth:`aggregate` if a leaf lacks values.
        """
        sep = separetor if separator is None else separator
        root = cls()
        for name, timer in timers.items():
            node = root
            for path in name.split(sep):
                if path not in node.children:
                    node.children[path] = cls()
                node = node.children[path]
            # ``elpased_time`` (sic) is the property name Stopwatch has always exposed.
            node.elapsed_time = timer.elpased_time
            node.ncalls = timer.ncalls

        if root.children:
            root.aggregate()
        else:
            # Nothing measured: a childless root would otherwise be treated as
            # a leaf with missing values and raise in ``aggregate``.
            root.elapsed_time = 0.0
            root.ncalls = 0
        return root

    def result(self, nsentences: int = -1) -> list[dict[str, str | int | float]]:
        """Flatten the tree into statistics rows in depth-first pre-order.

        The root node itself is excluded. Each row has ``name`` (the
        ``/``-joined path), ``acctime`` (seconds), ``acccalls`` and ``ms/call``.
        ``ms/call`` is NaN when ``acccalls`` is 0.

        Args:
            nsentences: If positive, ``ms/sentence`` and ``calls/sentence``
                are also reported.

        Returns:
            list[dict[str, str | int | float]]: One row per non-root node.
        """

        def _result(name: str, node: ProfileTree) -> list[dict[str, str | int | float]]:
            stat: dict[str, str | int | float] = {
                "name": name.strip("/"),
                "acctime": node.elapsed_time,
                "acccalls": node.ncalls,
                # A timer may have zero calls (e.g. after a reset or
                # ``set_delta_ncalls(0)``); avoid ZeroDivisionError.
                "ms/call": (
                    node.elapsed_time * 1000 / node.ncalls
                    if node.ncalls != 0
                    else float("nan")
                ),
            }
            if nsentences > 0:
                stat["ms/sentence"] = node.elapsed_time * 1000 / nsentences
                stat["calls/sentence"] = node.ncalls / nsentences
            res = [stat]
            for path, child in node.children.items():
                res += _result(name + "/" + path, child)
            return res

        res = _result("", self)
        return res[1:]  # Remove the root node
195
196
[docs]
def aggregate() -> ProfileTree:
    """Summarize the module-level ``measure`` stopwatches as a profile tree.

    Timer names are split on the default ``"/"`` separator of
    :meth:`ProfileTree.build`.

    Returns:
        ProfileTree: The root of the aggregated tree; call
        :meth:`ProfileTree.result` on it to obtain the result rows.
    """
    tree = ProfileTree.build(timers=measure)
    return tree