1"""
2Module provides helper functions.
3"""
4import os
5import math
6import tempfile
7import zipfile
8from pathlib import Path
9import re
10import requests
11import time
12import multiprocessing as mp
13from deprecated import deprecated
14from tqdm import tqdm
15import numpy as np
16import matplotlib
17import matplotlib.pyplot as plt
18from epanet_plus import EpanetConstants
19
20
# Integer IDs of the area, mass, and time measurement units that are handled by
# the unit helper functions of this module (e.g. `areaunit_to_id`,
# `massunit_to_id`, `qualityunit_to_id`, and the corresponding `*_to_str` functions).

# Area units -- mapped from/to the strings "FT2", "M2", and "CM2".
AREA_UNIT_FT2 = 1
AREA_UNIT_M2 = 2
AREA_UNIT_CM2 = 3
# Mass units -- mapped from/to the strings "MG", "UG", "MOL", and "MMOL".
# MASS_UNIT_MG and MASS_UNIT_UG are also used as quality units ("mg/L" and "ug/L").
MASS_UNIT_MG = 4
MASS_UNIT_UG = 5
MASS_UNIT_MOL = 6
MASS_UNIT_MMOL = 7
# Time unit -- used as the quality unit "hrs".
TIME_UNIT_HRS = 8
# Returned by `massunit_to_id` for every mass unit string that is not recognized.
MASS_UNIT_CUSTOM = 9
30
31
[docs]
def pressureunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given pressure unit ID.

    Parameters
    ----------
    unit_id : `int`
        Pressure unit ID.

        Must be one of the following EPANET constants:

            - EN_PSI    = 0  (Pounds per square inch)
            - EN_KPA    = 1  (Kilopascals)
            - EN_METERS = 2  (Meters)
            - EN_BAR    = 3  (Bar)
            - EN_FEET   = 4  (Feet)

    Returns
    -------
    `str`
        Pressure unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the pressure unit IDs listed above.
    """
    if unit_id is None:
        return ""

    # Ordered (ID, description) pairs -- the first ID that equals 'unit_id' wins.
    known_units = ((EpanetConstants.EN_PSI, "psi"),
                   (EpanetConstants.EN_KPA, "kilopascal"),
                   (EpanetConstants.EN_METERS, "meter"),
                   (EpanetConstants.EN_BAR, "bar"),
                   (EpanetConstants.EN_FEET, "feet"))
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
68
69
[docs]
def flowunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given flow unit ID.

    Parameters
    ----------
    unit_id : `int`
        Flow unit ID.

        Must be one of the following EPANET constants:

            - EN_CFS = 0  (cubic foot/sec)
            - EN_GPM = 1  (gal/min)
            - EN_MGD = 2  (Million gal/day)
            - EN_IMGD = 3  (Imperial MGD)
            - EN_AFD = 4  (ac-foot/day)
            - EN_LPS = 5  (liter/sec)
            - EN_LPM = 6  (liter/min)
            - EN_MLD = 7  (Megaliter/day)
            - EN_CMH = 8  (cubic meter/hr)
            - EN_CMD = 9  (cubic meter/day)
            - EN_CMS = 10 (cubic meter/sec)

    Returns
    -------
    `str`
        Flow unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the flow unit IDs listed above.
    """
    if unit_id is None:
        return ""

    # Ordered (ID, description) pairs -- the first ID that equals 'unit_id' wins.
    known_units = ((EpanetConstants.EN_CFS, "cubic foot/sec"),
                   (EpanetConstants.EN_GPM, "gal/min"),
                   (EpanetConstants.EN_MGD, "Million gal/day"),
                   (EpanetConstants.EN_IMGD, "Imperial MGD"),
                   (EpanetConstants.EN_AFD, "ac-foot/day"),
                   (EpanetConstants.EN_LPS, "liter/sec"),
                   (EpanetConstants.EN_LPM, "liter/min"),
                   (EpanetConstants.EN_MLD, "Megaliter/day"),
                   (EpanetConstants.EN_CMH, "cubic meter/hr"),
                   (EpanetConstants.EN_CMD, "cubic meter/day"),
                   (EpanetConstants.EN_CMS, "cubic meter/sec"))
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
124
125
[docs]
def areaunit_to_id(unit_desc: str) -> int:
    """
    Looks up the ID of an area unit given as a string.

    Parameters
    ----------
    unit_desc : `str`
        Area units string -- one of "FT2", "M2", or "CM2".

    Returns
    -------
    `int`
        Corresponding area unit ID (AREA_UNIT_FT2, AREA_UNIT_M2, or AREA_UNIT_CM2).

    Raises
    ------
    KeyError
        If 'unit_desc' is not one of the strings listed above.
    """
    area_unit_ids = {
        "FT2": AREA_UNIT_FT2,
        "M2": AREA_UNIT_M2,
        "CM2": AREA_UNIT_CM2,
    }

    return area_unit_ids[unit_desc]
143
144
[docs]
def massunit_to_id(unit_desc: str) -> int:
    """
    Looks up the ID of a mass unit given as a string.

    Parameters
    ----------
    unit_desc : `str`
        Mass units string.

        "MG", "UG", "MOL", and "MMOL" are recognized -- every other string is
        treated as a custom unit.

    Returns
    -------
    `int`
        Corresponding mass unit ID -- MASS_UNIT_CUSTOM if 'unit_desc' is not recognized.
    """
    known_mass_units = {
        "MG": MASS_UNIT_MG,
        "UG": MASS_UNIT_UG,
        "MOL": MASS_UNIT_MOL,
        "MMOL": MASS_UNIT_MMOL,
    }

    # Anything that is not a known unit string falls back to the custom unit ID.
    return known_mass_units.get(unit_desc, MASS_UNIT_CUSTOM)
168
169
[docs]
def qualityunit_to_id(unit_desc: str) -> int:
    """
    Looks up the unit ID of a quality measurement unit given as a string.

    Parameters
    ----------
    unit_desc : `str`
        Quality measurement unit -- "mg/L", "ug/L", or "hrs".

    Returns
    -------
    `int`
        Unit ID.

        None if 'unit_desc' is not recognized (e.g. if no water quality analysis
        was set up), otherwise one of the following constants:

            - MASS_UNIT_MG = 4      (mg/L)
            - MASS_UNIT_UG = 5      (ug/L)
            - TIME_UNIT_HRS = 8     (hrs)
    """
    known_quality_units = (("mg/L", MASS_UNIT_MG),
                           ("ug/L", MASS_UNIT_UG),
                           ("hrs", TIME_UNIT_HRS))
    for description, quality_unit_id in known_quality_units:
        if unit_desc == description:
            return quality_unit_id

    # Unknown descriptions are not an error -- they are reported as None.
    return None
199
200
[docs]
def massunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given mass unit ID.

    Parameters
    ----------
    unit_id : `int`
        ID of the mass unit.

        Must be one of the following constants:

            - MASS_UNIT_MG = 4
            - MASS_UNIT_UG = 5
            - MASS_UNIT_MOL = 6
            - MASS_UNIT_MMOL = 7
            - MASS_UNIT_CUSTOM = 9

    Returns
    -------
    `str`
        Mass unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the mass unit IDs listed above.
    """
    if unit_id is None:
        return ""

    # Ordered (ID, description) pairs -- the first ID that equals 'unit_id' wins.
    known_units = ((MASS_UNIT_MG, "MG"),
                   (MASS_UNIT_UG, "UG"),
                   (MASS_UNIT_MOL, "MOL"),
                   (MASS_UNIT_MMOL, "MMOL"),
                   (MASS_UNIT_CUSTOM, "CUSTOM UNIT"))
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown mass unit ID '{unit_id}'")
237
238
[docs]
def qualityunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given quality measurement unit ID.

    Parameters
    ----------
    unit_id : `int`
        ID of the quality unit.

        Must be one of the following constants:

            - MASS_UNIT_MG = 4      (mg/L)
            - MASS_UNIT_UG = 5      (ug/L)
            - TIME_UNIT_HRS = 8     (hrs)

    Returns
    -------
    `str`
        Quality unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the unit IDs listed above.
    """
    if unit_id is None:
        return ""

    # Ordered (ID, description) pairs -- the first ID that equals 'unit_id' wins.
    known_units = ((MASS_UNIT_MG, "mg/L"),
                   (MASS_UNIT_UG, "ug/L"),
                   (TIME_UNIT_HRS, "hrs"))
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
269
270
[docs]
def areaunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given area measurement unit ID.

    Parameters
    ----------
    unit_id : `int`
        ID of the area unit.

        Must be one of the following constants:

            - AREA_UNIT_FT2 = 1
            - AREA_UNIT_M2  = 2
            - AREA_UNIT_CM2 = 3

    Returns
    -------
    `str`
        Area unit description ("FT2", "M2", or "CM2").

        Note that None (and not an empty string) is returned if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the area unit IDs listed above.
    """
    if unit_id is None:
        return None

    # Ordered (ID, description) pairs -- the first ID that equals 'unit_id' wins.
    known_units = ((AREA_UNIT_FT2, "FT2"),
                   (AREA_UNIT_M2, "M2"),
                   (AREA_UNIT_CM2, "CM2"))
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
301
302
[docs]
def is_flowunit_simetric(unit_id: int) -> bool:
    """
    Checks whether a given flow unit is one of the SI metric flow units.

    Parameters
    ----------
    unit_id : `int`
        ID of the flow unit.

        Must be one of the following EPANET constants:

            - EN_CFS = 0  (cubic foot/sec)
            - EN_GPM = 1  (gal/min)
            - EN_MGD = 2  (Million gal/day)
            - EN_IMGD = 3  (Imperial MGD)
            - EN_AFD = 4  (ac-foot/day)
            - EN_LPS = 5  (liter/sec)
            - EN_LPM = 6  (liter/min)
            - EN_MLD = 7  (Megaliter/day)
            - EN_CMH = 8  (cubic meter/hr)
            - EN_CMD = 9  (cubic meter/day)
            - EN_CMS = 10 (cubic meter/sec)

    Returns
    -------
    `bool`
        True if the given unit is an SI metric unit (EN_LPS, EN_LPM, EN_MLD, EN_CMH,
        EN_CMD, or EN_CMS), False otherwise.
    """
    si_metric_flow_units = (
        EpanetConstants.EN_LPS,
        EpanetConstants.EN_LPM,
        EpanetConstants.EN_MLD,
        EpanetConstants.EN_CMH,
        EpanetConstants.EN_CMD,
        EpanetConstants.EN_CMS,
    )

    return unit_id in si_metric_flow_units
333
334
def _get_pressure_convert_factor(new_unit_id: int, old_unit: int) -> float:
    """
    Returns the factor by which a pressure value given in 'old_unit' has to be
    multiplied in order to obtain the same pressure in 'new_unit_id'.

    Parameters
    ----------
    new_unit_id : `int`
        ID of the pressure unit to convert to.

        Must be one of the following EPANET constants:
        EN_PSI, EN_KPA, EN_METERS, EN_BAR, EN_FEET.
    old_unit : `int`
        ID of the pressure unit to convert from -- same set of constants as 'new_unit_id'.

    Returns
    -------
    `float`
        Conversion factor -- 1 if both units are equal.

    Raises
    ------
    ValueError
        If 'new_unit_id' or 'old_unit' is not a known pressure unit ID
        (unless both are equal).
    """
    if new_unit_id == old_unit:
        return 1.

    psi = EpanetConstants.EN_PSI
    kpa = EpanetConstants.EN_KPA
    meters = EpanetConstants.EN_METERS
    bar = EpanetConstants.EN_BAR
    feet = EpanetConstants.EN_FEET

    # factors[new unit][old unit] -> multiplier for converting old unit to new unit.
    factors = {
        bar: {psi: .0689476,
              meters: .09804139432,
              feet: .029883016988736,
              kpa: .01},
        kpa: {psi: 6.894744825494,
              meters: 9.804139432,
              feet: 2.9890669,
              bar: 100.},
        feet: {psi: 2.3072493927233,
               meters: 3.2808398950131,
               kpa: .33455256555148,
               bar: 33.455256555148},
        psi: {kpa: .14503773773020923,
              meters: 1.4219702063247,
              feet: .43341651888775,
              bar: 14.5038},
        meters: {psi: .70325,
                 kpa: .10199773339984,
                 feet: .3048,
                 bar: 10.199773339984},
    }

    try:
        factors_to_new_unit = factors[new_unit_id]
    except (KeyError, TypeError):
        raise ValueError("Invalid 'new_unit_id'") from None

    # An unknown source unit must not silently yield None -- the caller would
    # otherwise fail later (or compute garbage) when applying the factor.
    try:
        return factors_to_new_unit[old_unit]
    except (KeyError, TypeError):
        raise ValueError("Invalid 'old_unit'") from None
391
392
def _get_flow_convert_factor(new_unit_id: int, old_unit: int) -> float:
    """
    Returns the factor by which a flow value given in 'old_unit' has to be
    multiplied in order to obtain the same flow in 'new_unit_id'.

    Parameters
    ----------
    new_unit_id : `int`
        ID of the flow unit to convert to.

        Must be one of the following EPANET constants:
        EN_CFS, EN_GPM, EN_MGD, EN_IMGD, EN_AFD, EN_LPS, EN_LPM, EN_MLD,
        EN_CMH, EN_CMD, EN_CMS.
    old_unit : `int`
        ID of the flow unit to convert from -- same set of constants as 'new_unit_id'.

    Returns
    -------
    `float`
        Conversion factor -- 1 if both units are equal.

    Raises
    ------
    ValueError
        If 'new_unit_id' or 'old_unit' is not a known flow unit ID
        (unless both are equal).
    """
    if new_unit_id == old_unit:
        return 1.

    cfs = EpanetConstants.EN_CFS
    gpm = EpanetConstants.EN_GPM
    mgd = EpanetConstants.EN_MGD
    imgd = EpanetConstants.EN_IMGD
    afd = EpanetConstants.EN_AFD
    lps = EpanetConstants.EN_LPS
    lpm = EpanetConstants.EN_LPM
    mld = EpanetConstants.EN_MLD
    cmh = EpanetConstants.EN_CMH
    cmd = EpanetConstants.EN_CMD
    cms = EpanetConstants.EN_CMS

    # factors[new unit][old unit] -> multiplier for converting old unit to new unit.
    # EN_CMS is not part of this table -- it is handled via EN_LPS below.
    factors = {
        cfs: {gpm: .0022280093,
              mgd: 1.5472286523,
              imgd: 1.8581441347,
              afd: .5041666667,
              lps: .0353146667,
              lpm: .0005885778,
              mld: .40873456853575,
              cmh: .0098096296,
              cmd: .0004087346},
        gpm: {cfs: 448.8325660485,
              mgd: 694.44444444,
              imgd: 833.99300382,
              afd: 226.28571429,
              lps: 15.850323141,
              lpm: .2641720524,
              mld: 183.4528141376,
              cmh: 4.4028675393,
              cmd: .1834528141},
        mgd: {cfs: .6463168831,
              gpm: .00144,
              imgd: 1.2009499255,
              afd: 0.3258514286,
              lps: .0228244653,
              lpm: .0003804078,
              mld: .26417205124156,
              cmh: .0063401293,
              cmd: .0002641721},
        imgd: {cfs: .5381713837,
               mgd: .8326741846,
               gpm: .0011990508,
               afd: .2713280726,
               lps: .0190053431,
               lpm: .0003167557,
               mld: .21996924829908776,
               cmh: .005279262,
               cmd: .0002199692},
        afd: {cfs: 1.9834710744,
              mgd: 3.0688832772,
              gpm: .0044191919,
              imgd: 3.6855751432,
              lps: .0700456199,
              lpm: .001167427,
              mld: .81070995093708,
              cmh: .0194571167,
              cmd: .0008107132},
        lps: {cfs: 28.316846592,
              mgd: 43.812636389,
              imgd: 52.616782407,
              gpm: .0630901964,
              afd: 14.276410157,
              lpm: .0166666667,
              mld: 11.574074074074,
              cmh: .2777777778,
              cmd: .0115740741},
        lpm: {cfs: 1699.0107955,
              mgd: 2628.7581833,
              imgd: 3157.0069444,
              afd: 856.58460941,
              lps: 60,
              gpm: 3.785411784,
              mld: 694.44444444443,
              cmh: 16.666666667,
              cmd: 0.6944444444},
        mld: {cfs: 2.4465755456688,
              mgd: 3.7854117999999777,
              imgd: 4.54609,
              afd: 1.2334867714947,
              lps: .0864,
              lpm: .00144,
              gpm: .00545099296896,
              cmh: .024,
              cmd: .00099999999999999},
        cmh: {cfs: 101.94064773,
              mgd: 157.725491,
              imgd: 189.42041667,
              afd: 51.395076564,
              lps: 3.6,
              lpm: .06,
              mld: 41.666666666666,
              gpm: .227124707,
              cmd: 0.0416666667},
        cmd: {cfs: 2446.5755455,
              mgd: 3785.411784,
              imgd: 4546.09,
              afd: 1233.4818375,
              lps: 86.4,
              lpm: 1.44,
              mld: 1000.,
              cmh: 24,
              gpm: 5.450992969},
    }

    def lookup(target_unit, source_unit):
        # Factor for converting 'source_unit' into 'target_unit' (both not EN_CMS).
        if target_unit == source_unit:
            return 1.
        try:
            factors_to_target = factors[target_unit]
        except (KeyError, TypeError):
            raise ValueError("Invalid 'new_unit_id'") from None
        try:
            return factors_to_target[source_unit]
        except (KeyError, TypeError):
            raise ValueError("Invalid 'old_unit'") from None

    # Cubic meter/sec is converted via liter/sec: 1 cubic meter/sec = 1000 liter/sec.
    liters_per_cubic_meter = 1000.
    if new_unit_id == cms:
        return lookup(lps, old_unit) / liters_per_cubic_meter
    if old_unit == cms:
        return lookup(new_unit_id, lps) * liters_per_cubic_meter

    return lookup(new_unit_id, old_unit)
596
597
[docs]
def time_points_to_one_hot_encoding(time_points: list[int], total_length: int) -> list[int]:
    """
    Builds a 0/1 list of a given length in which the entries at the given
    time points (indices) are set to one.

    Parameters
    ----------
    time_points : `list[int]`
        Time points (indices) to be marked with a one.
    total_length : `int`
        Length of the final encoding.

    Returns
    -------
    `list[int]`
        List of length 'total_length' containing a one at every index in
        'time_points' and a zero everywhere else.
    """
    encoding = [0 for _ in range(total_length)]

    # Plain index assignment -- duplicates in 'time_points' are harmless.
    for time_idx in time_points:
        encoding[time_idx] = 1

    return encoding
620
621
[docs]
def volume_to_level(tank_volume: float, tank_diameter: float) -> float:
    """
    Computes the water level in a tank containing a given volume of water.

    The level is computed as ``tank_volume * 4 / (pi * tank_diameter^2)``,
    i.e. the volume divided by the area of a circle of the given diameter.

    Parameters
    ----------
    tank_volume : `float`
        Water volume in the tank -- must not be negative.
        Integers are accepted as well.
    tank_diameter : `float`
        Diameter of the tank -- must be greater than zero.
        Integers are accepted as well.

    Returns
    -------
    `float`
        Water level in tank.

    Raises
    ------
    TypeError
        If an argument is neither a 'float' nor an 'int' (booleans are rejected).
    ValueError
        If 'tank_volume' is negative or 'tank_diameter' is not greater than zero.
    """
    # 'bool' is a subclass of 'int' but is never a meaningful quantity here.
    if isinstance(tank_volume, bool) or not isinstance(tank_volume, (int, float)):
        raise TypeError("'tank_volume' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_volume)}'")
    if tank_volume < 0:
        raise ValueError("'tank_volume' can not be negative")
    if isinstance(tank_diameter, bool) or not isinstance(tank_diameter, (int, float)):
        raise TypeError("'tank_diameter' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_diameter)}'")
    if tank_diameter <= 0:
        raise ValueError("'tank_diameter' must be greater than zero")

    return (4. / (math.pow(tank_diameter, 2) * math.pi)) * tank_volume
650
651
[docs]
def level_to_volume(tank_level: float, tank_diameter: float) -> float:
    """
    Computes the volume of water in a tank given the water level in the tank.

    The volume is computed as ``tank_level * pi * (tank_diameter / 2)^2``,
    i.e. the level multiplied by the area of a circle of the given diameter.

    Parameters
    ----------
    tank_level : `float`
        Water level in the tank -- must not be negative.
        Integers are accepted as well.
    tank_diameter : `float`
        Diameter of the tank -- must be greater than zero.
        Integers are accepted as well.

    Returns
    -------
    `float`
        Water volume in tank.

    Raises
    ------
    TypeError
        If an argument is neither a 'float' nor an 'int' (booleans are rejected).
    ValueError
        If 'tank_level' is negative or 'tank_diameter' is not greater than zero.
    """
    # 'bool' is a subclass of 'int' but is never a meaningful quantity here.
    if isinstance(tank_level, bool) or not isinstance(tank_level, (int, float)):
        raise TypeError("'tank_level' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_level)}'")
    if tank_level < 0:
        raise ValueError("'tank_level' can not be negative")
    if isinstance(tank_diameter, bool) or not isinstance(tank_diameter, (int, float)):
        raise TypeError("'tank_diameter' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_diameter)}'")
    if tank_diameter <= 0:
        raise ValueError("'tank_diameter' must be greater than zero")

    return tank_level * math.pow(0.5 * tank_diameter, 2) * math.pi
680
681
[docs]
def plot_timeseries_data(data: np.ndarray, labels: list[str] = None, x_axis_label: str = None,
                         y_axis_label: str = None, y_ticks: tuple[list[float], list[str]] = None,
                         show: bool = True, save_to_file: str = None,
                         ax: matplotlib.axes.Axes = None) -> matplotlib.axes.Axes:
    """
    Plots one or several time series into a single plot.

    Parameters
    ----------
    data : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Two-dimensional array of time series data -- each row in `data` is plotted
        as one time series.
    labels : `list[str]`, optional
        One label per time series (row) in `data`.
        If None, no labels and no legend are shown.

        The default is None.
    x_axis_label : `str`, optional
        X axis label.

        The default is None.
    y_axis_label : `str`, optional
        Y axis label.

        The default is None.
    y_ticks: `(list[float], list[str])`, optional
        Tuple of ticks (numbers) and labels (strings) for the y-axis.

        The default is None.
    show : `bool`, optional
        If True, the plot/figure is shown in a window.

        Only considered when 'ax' is None.

        The default is True.
    save_to_file : `str`, optional
        File to which the plot is saved -- the parent folder is created via
        `create_path_if_not_exist`.

        If specified, 'show' must be set to False --
        i.e. a plot can not be shown and saved to a file at the same time!

        The default is None.
    ax : `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_, optional
        If not None, 'ax' is used for plotting -- otherwise a new figure is created.

        The default is None.

    Returns
    -------
    `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_
        Axes that were used for plotting.

    Raises
    ------
    TypeError
        If an argument is of the wrong type.
    ValueError
        If 'data' is not two-dimensional, 'y_ticks' does not have exactly two items,
        or 'save_to_file' is combined with 'show=True'.
    """
    # Argument validation -- the order of the checks determines which error is raised first.
    if not isinstance(data, np.ndarray):
        raise TypeError(f"'data' must be an instance of 'numpy.ndarray' but not of '{type(data)}'")
    if len(data.shape) != 2:
        raise ValueError("'data' must be a 2d array where each row corresponds to a time series "
                         "-- use '.reshape(1, -1)' in case of single time series")
    if labels is not None:
        if not (isinstance(labels, list) and all(isinstance(label, str) for label in labels)):
            raise TypeError("'labels' must be a instance of 'list[str]'")
    if x_axis_label is not None and not isinstance(x_axis_label, str):
        raise TypeError("'x_axis_label' must be an instance of 'str' "
                        f"but not of '{type(x_axis_label)}'")
    if y_axis_label is not None and not isinstance(y_axis_label, str):
        raise TypeError("'y_axis_label' must be an instance of 'str' "
                        f"but not of '{type(y_axis_label)}'")
    if y_ticks is not None and len(y_ticks) != 2:
        raise ValueError("'y_ticks' must be a tuple ticks (numbers) and labels (strings)")
    if not isinstance(show, bool):
        raise TypeError(f"'show' must be an instance of 'bool' but not of '{type(show)}'")
    if save_to_file is not None:
        if show is True:
            raise ValueError("'show' must be False if 'save_to_file' is set")
        if not isinstance(save_to_file, str):
            raise TypeError("'save_to_file' must be an instance of 'str' but not of "
                            f"'{type(save_to_file)}'")
    if ax is not None and not isinstance(ax, matplotlib.axes.Axes):
        raise TypeError("ax' must be an instance of 'matplotlib.axes.Axes'"
                        f"but not of '{type(ax)}'")

    # Only create (and later show) a figure if the caller did not provide any axes.
    own_figure = None
    if ax is None:
        own_figure, ax = plt.subplots()

    n_series = data.shape[0]
    if labels is None:
        labels = [None] * n_series

    for series_idx in range(n_series):
        ax.plot(data[series_idx, :], ".-", label=labels[series_idx])

    # A legend is only drawn if every time series comes with a label.
    if all(label is not None for label in labels):
        ax.legend()

    if x_axis_label is not None:
        ax.set_xlabel(x_axis_label)
    if y_axis_label is not None:
        ax.set_ylabel(y_axis_label)
    if y_ticks is not None:
        tick_positions, tick_labels = y_ticks
        ax.set_yticks(tick_positions, labels=tick_labels)

    if own_figure is not None and show is True:
        plt.show()
    if save_to_file is not None:
        create_path_if_not_exist(str(Path(save_to_file).parent.absolute()))

        # Without an own figure, saving goes through pyplot (i.e. the current figure).
        figure_writer = plt if own_figure is None else own_figure
        figure_writer.savefig(save_to_file, bbox_inches='tight')

    return ax
798
799
[docs]
def plot_timeseries_prediction(y: np.ndarray, y_pred: np.ndarray,
                               confidence_interval: np.ndarray = None,
                               x_axis_label: str = None, y_axis_label: str = None,
                               y_ticks: tuple[list[float], list[str]] = None,
                               show: bool = True, save_to_file: str = None,
                               ax: matplotlib.axes.Axes = None
                               ) -> matplotlib.axes.Axes:
    """
    Plots the prediction (e.g. forecast) of a *single* time series together with the
    ground truth time series. Optionally, a confidence band around the prediction
    is plotted as well.

    Parameters
    ----------
    y : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Ground truth values -- one-dimensional array.
    y_pred : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Predicted values -- one-dimensional array of the same shape as `y`.
    confidence_interval : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Confidence interval of the predictions in `y_pred`.
        If not None, the area between ``y_pred - confidence_interval[0]`` and
        ``y_pred + confidence_interval[1]`` is filled.

        The default is None.
    x_axis_label : `str`, optional
        X axis label.

        The default is None.
    y_axis_label : `str`, optional
        Y axis label.

        The default is None.
    y_ticks: `(list[float], list[str])`, optional
        Tuple of ticks (numbers) and labels (strings) for the y-axis.

        The default is None.
    show : `bool`, optional
        If True, the plot/figure is shown in a window.

        Only considered when 'ax' is None.

        The default is True.
    save_to_file : `str`, optional
        File to which the plot is saved -- the parent folder is created via
        `create_path_if_not_exist`.

        If specified, 'show' must be set to False --
        i.e. a plot can not be shown and saved to a file at the same time!

        The default is None.
    ax : `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_, optional
        If not None, 'ax' is used for plotting -- otherwise a new figure is created.

        The default is None.

    Returns
    -------
    `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_
        Axes that were used for plotting.

    Raises
    ------
    TypeError
        If an argument is of the wrong type.
    ValueError
        If `y` and `y_pred` differ in shape or are not one-dimensional, 'y_ticks' does
        not have exactly two items, or 'save_to_file' is combined with 'show=True'.
    """
    # Argument validation -- the order of the checks determines which error is raised first.
    if not isinstance(y_pred, np.ndarray):
        raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' "
                        f"but not of '{type(y_pred)}'")
    if not isinstance(y, np.ndarray):
        raise TypeError("'y' must be an instance of 'numpy.ndarray' "
                        f"but not of '{type(y)}'")
    if y_pred.shape != y.shape:
        raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
    if len(y_pred.shape) != 1:
        raise ValueError("'y_pred' must be a 1d array")
    if len(y.shape) != 1:
        raise ValueError("'y' must be a 1d array")
    if x_axis_label is not None and not isinstance(x_axis_label, str):
        raise TypeError("'x_axis_label' must be an instance of 'str' "
                        f"but not of '{type(x_axis_label)}'")
    if y_axis_label is not None and not isinstance(y_axis_label, str):
        raise TypeError("'y_axis_label' must be an instance of 'str' "
                        f"but not of '{type(y_axis_label)}'")
    if y_ticks is not None and len(y_ticks) != 2:
        raise ValueError("'y_ticks' must be a tuple ticks (numbers) and labels (strings)")
    if not isinstance(show, bool):
        raise TypeError(f"'show' must be an instance of 'bool' but not of '{type(show)}'")
    if save_to_file is not None:
        if show is True:
            raise ValueError("'show' must be False if 'save_to_file' is set")
        if not isinstance(save_to_file, str):
            raise TypeError("'save_to_file' must be an instance of 'str' but not of "
                            f"'{type(save_to_file)}'")
    if ax is not None and not isinstance(ax, matplotlib.axes.Axes):
        raise TypeError("ax' must be an instance of 'matplotlib.axes.Axes'"
                        f"but not of '{type(ax)}'")

    # Only create (and later show) a figure if the caller did not provide any axes.
    own_figure = None
    if ax is None:
        own_figure, ax = plt.subplots()

    if confidence_interval is not None:
        band_lower = y_pred - confidence_interval[0]
        band_upper = y_pred + confidence_interval[1]
        ax.fill_between(range(len(y_pred)), band_lower, band_upper, alpha=0.5)
    ax.plot(y, ".-", label="Ground truth")
    ax.plot(y_pred, ".-", label="Prediction")
    ax.legend()

    if x_axis_label is not None:
        ax.set_xlabel(x_axis_label)
    if y_axis_label is not None:
        ax.set_ylabel(y_axis_label)
    if y_ticks is not None:
        tick_positions, tick_labels = y_ticks
        ax.set_yticks(tick_positions, labels=tick_labels)

    if own_figure is not None and show is True:
        plt.show()
    if save_to_file is not None:
        create_path_if_not_exist(str(Path(save_to_file).parent.absolute()))

        # Without an own figure, saving goes through pyplot (i.e. the current figure).
        figure_writer = plt if own_figure is None else own_figure
        figure_writer.savefig(save_to_file, bbox_inches='tight')

    return ax
927
928
[docs]
929def robust_download(download_path: str, urls: str | list,
930 verbose: bool = True, timeout: int = 30) -> None:
931 """
932 Downloads a file from the given urls if it does not already exist in the
933 given path. The urls are tried in order. If a download stops or stalls,
934 the next url is tried until one succeeds or all urls have failed.
935
936 Parameters
937 ----------
938 download_path : `str`
939 Local path to the file -- if this path does not exist, the file will be
940 downloaded from the provided 'urls' and stored there.
941 urls : `list` or `str`
942 One url or a list of urls (where additional urls function as backup) to
943 download the file from.
944 verbose : `bool`, optional
945 If True, a progress bar is shown while downloading the file.
946
947 The default is True.
948 timeout : `int`
949 If this time passed without progress while downloading, the download is
950 considered failed.
951
952 The default is 30 seconds.
953 """
954 if isinstance(urls, str):
955 urls = [urls]
956
957 for url in urls:
958 try:
959 download_if_necessary(download_path, url, verbose, timeout)
960 return
961 except Exception as e:
962 print(f"Failed url: {url} with {e}")
963 continue
964
965 raise SystemError("All download attempts failed")
966
967
def _download_process(download_path: str, url: str, backup_urls: list[str],
                      last_update: mp.Value, stop_flag: mp.Value,
                      finish_flag: mp.Value, verbose: bool) -> None:
    """
    Process that handles the actual download. It updates the last download
    update variable and cleans up the corrupted file if the download fails from
    within.

    This function is only to be called from `download_if_necessary`.

    Parameters
    ----------
    download_path : `str`
        Local path where the downloaded file is stored.
    url : `str`
        Web-URL pointing to the source the file should be downloaded from. Can
        also point to a google drive file.
    backup_urls : `list[str]`
        List of alternative URLs that are being tried in the case that
        downloading from 'url' fails. This is deprecated, but left in for
        downward compatibility with `download_if_necessary` calls with
        backup_urls. Not necessary when using `robust_download`.
    last_update : `mp.Value`
        Shared variable -- set to the current time after every written chunk.
    stop_flag : `mp.Value`
        Shared variable. Set to 1 when this process stopped by finishing or
        error.
    finish_flag : `mp.Value`
        Shared variable. Set to 1 when download finished successfully.
    verbose : `bool`
        If True, a progress bar is shown while downloading the file.

    Raises
    ------
    SystemError
        If the google drive download parameters can not be extracted or
        if no url answers with status code 200.
    """
    progress_bar = None
    response = None
    session = None

    try:
        if "drive.google.com" in url:
            # Google drive: parse the hidden form fields of the landing page
            # and build the actual download url from them.
            session = requests.Session()
            response = session.get(url)
            html = response.text

            def extract(pattern):
                match = re.search(pattern, html)
                return match.group(1) if match else None

            file_id = extract(r'name="id" value="([^"]+)"')
            file_confirm = extract(r'name="confirm" value="([^"]+)"')
            file_uuid = extract(r'name="uuid" value="([^"]+)"')

            if not all([file_id, file_confirm, file_uuid]):
                raise SystemError("Failed to extract download parameters")

            download_url = (
                f"https://drive.usercontent.google.com/download"
                f"?id={file_id}&export=download&confirm={file_confirm}&uuid={file_uuid}"
            )

            response = session.get(download_url, stream=True)

        else:
            response = requests.get(url, stream=True, allow_redirects=True,
                                    timeout=1000)

        # Deprecated, left in for backward compatibility
        if response.status_code != 200:
            for backup_url in backup_urls:
                # Release the failed response before trying the next url.
                response.close()
                # Always stream -- otherwise the whole file would be loaded into
                # memory before 'last_update' could be refreshed for the first time.
                response = requests.get(backup_url, stream=True,
                                        allow_redirects=True, timeout=1000)
                if response.status_code == 200:
                    break
        if response.status_code != 200:
            raise SystemError(f"Failed to download -- {response.status_code}")

        content_length = int(response.headers.get("content-length", 0))
        with open(download_path, "wb") as file:
            if verbose:
                progress_bar = tqdm(desc=download_path, total=content_length,
                                    ascii=True, unit='B', unit_scale=True,
                                    unit_divisor=1024)
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                # Explicit None check: the truth value of a tqdm instance depends on
                # its total, which is zero if no content-length was reported.
                if progress_bar is not None:
                    progress_bar.update(size)
                with last_update.get_lock():
                    last_update.value = time.time()
        with finish_flag.get_lock():
            finish_flag.value = 1
        with stop_flag.get_lock():
            stop_flag.value = 1
    finally:
        if progress_bar is not None:
            progress_bar.close()
        # Explicit None check: a response with an error status code is falsy
        # but still has to be closed.
        if response is not None:
            response.close()
        if session is not None:
            session.close()
        # Remove the (partial) file if the download did not finish successfully.
        with finish_flag.get_lock():
            if os.path.exists(download_path) and finish_flag.value == 0:
                os.remove(download_path)
        with stop_flag.get_lock():
            stop_flag.value = 1
1069
1070
[docs]
@deprecated(reason="Please use new function `robust_download` instead.")
def download_from_gdrive_if_necessary(download_path: str, url: str, verbose: bool = True) -> None:
    """
    Downloads a file from a google drive repository if it does not already exist
    in a given path.

    Note that if the path (folder) does not already exist, it will be created.
    If the download fails or is interrupted while the file is being written,
    the partially written file is removed again so that a later call does not
    mistake it for a complete download.

    Parameters
    ----------
    download_path : `str`
        Local path to the file -- if this path does not exist, the file will be downloaded from
        the provided 'url' and stored at this path.
    url : `str`
        Web-URL of the google drive repository.
    verbose : `bool`, optional
        If True, a progress bar is shown while downloading the file.

        The default is True.

    Raises
    ------
    SystemError
        If the download parameters can not be extracted from the page served at 'url',
        or if the download request is not answered with status code 200.
    """
    folder_path = str(Path(download_path).parent.absolute())
    create_path_if_not_exist(folder_path)

    if os.path.isfile(download_path):
        return

    # The session (and every response) is used as a context manager so that
    # the underlying connections are released on success and on failure.
    with requests.Session() as session:
        with session.get(url) as landing_page:
            html = landing_page.text

        def extract(pattern):
            match = re.search(pattern, html)
            return match.group(1) if match else None

        # Hidden form fields of the page served at 'url' that are needed
        # to build the actual download request.
        file_id = extract(r'name="id" value="([^"]+)"')
        file_confirm = extract(r'name="confirm" value="([^"]+)"')
        file_uuid = extract(r'name="uuid" value="([^"]+)"')

        if not all([file_id, file_confirm, file_uuid]):
            raise SystemError("Failed to extract download parameters")

        download_url = (
            f"https://drive.usercontent.google.com/download"
            f"?id={file_id}&export=download&confirm={file_confirm}&uuid={file_uuid}"
        )

        with session.get(download_url, stream=True) as response:
            if response.status_code != 200:
                raise SystemError(f"Failed to download -- {response.status_code}")

            content_length = int(response.headers.get('content-length', 0))
            try:
                # Always stream the payload chunk-wise to disk; the progress bar
                # is merely disabled (not skipped) if 'verbose' is False.
                with open(download_path, "wb") as file, \
                        tqdm(desc=download_path, total=content_length, ascii=True,
                             unit='B', unit_scale=True, unit_divisor=1024,
                             disable=not verbose) as progress_bar:
                    for data in response.iter_content(chunk_size=1024):
                        size = file.write(data)
                        progress_bar.update(size)
            except BaseException:
                # Do not leave a truncated file behind -- it would be treated
                # as an already finished download by the next call.
                if os.path.exists(download_path):
                    os.remove(download_path)
                raise
1135
1136
[docs]
def download_if_necessary(download_path: str, url: str, verbose: bool = True,
                          backup_urls: list[str] = None, timeout: int = 30) -> None:
    """
    Downloads a file from a given URL if it does not already exist in a given
    path. This function is deprecated, please use `robust_download` instead.

    Note that if the path (folder) does not already exist, it will be created.

    The download itself runs in a separate process, which is monitored (once per second)
    by the calling process and terminated if it does not report any progress for
    more than 'timeout' seconds.

    Parameters
    ----------
    download_path : `str`
        Local path to the file -- if this path does not exist, the file will be
        downloaded from the provided 'url' and stored at this path.
    url : `str`
        Web-URL.
    verbose : `bool`, optional
        If True, a progress bar is shown while downloading the file.

        The default is True.
    backup_urls : `list[str]`, optional
        List of alternative URLs that are being tried in the case that downloading from 'url' fails.

        The default is None, which is treated as an empty list.
    timeout : `int`, optional
        Allowed download inactivity in seconds. After this time passed without
        an update, the download is considered failed.

        The default is 30 seconds.

    Raises
    ------
    SystemError
        If the download process stops without finishing successfully, or
        if no progress was made for more than 'timeout' seconds.
    """
    folder_path = str(Path(download_path).parent.absolute())
    create_path_if_not_exist(folder_path)

    if os.path.isfile(download_path):
        return

    # Avoid a shared mutable default argument.
    if backup_urls is None:
        backup_urls = []

    last_update = mp.Value('d', time.time())
    stop_flag = mp.Value('i', 0)
    finish_flag = mp.Value('i', 0)

    worker = mp.Process(target=_download_process,
                        args=(download_path, url, backup_urls, last_update,
                              stop_flag, finish_flag, verbose))
    worker.start()

    while True:
        time.sleep(1)
        with last_update.get_lock():
            idle = time.time() - last_update.value

        stopped = False
        finished = False
        with stop_flag.get_lock():
            if stop_flag.value == 1:
                stopped = True
                with finish_flag.get_lock():
                    finished = finish_flag.value == 1

        if stopped:
            if finished:
                break

            # The worker stopped without finishing: reap the process (no shared
            # lock is held at this point) and discard any partial file.
            worker.join()
            if os.path.exists(download_path):
                os.remove(download_path)
            raise SystemError(f"failed downloading from {url}")

        if idle > timeout:
            # Hold the lock of 'finish_flag' while terminating, so that the worker
            # can not be killed while it is holding this lock itself.
            with finish_flag.get_lock():
                worker.terminate()
                worker.join()
                if os.path.exists(download_path) and finish_flag.value == 0:
                    os.remove(download_path)
                raise SystemError(f"no progress in {timeout} seconds, aborting download")

    worker.join()
1200
1201
[docs]
def create_path_if_not_exist(path_in: str) -> None:
    """
    Makes sure that a given directory exists -- missing parent directories
    are created as well, and an already existing directory is left untouched.

    Parameters
    ----------
    path_in : `str`
        Path of the directory that is supposed to exist afterwards.
    """
    target_dir = Path(path_in)
    target_dir.mkdir(exist_ok=True, parents=True)
1212
1213
[docs]
def pack_zip_archive(f_in: list[str], f_out: str) -> None:
    """
    Creates a .zip archive containing the given files, where
    each file is stored using the DEFLATE compression method.

    Parameters
    ----------
    f_in : `list[str]`
        Paths of the files that are added to the .zip archive.
    f_out : `str`
        Path where the .zip archive is written to.
    """
    with zipfile.ZipFile(f_out, mode="w") as archive:
        for file_path in f_in:
            archive.write(file_path, compress_type=zipfile.ZIP_DEFLATED)
1228
1229
[docs]
def unpack_zip_archive(f_in: str, folder_out: str) -> None:
    """
    Extracts all members of a .zip archive into a given folder.

    Parameters
    ----------
    f_in : `str`
        Path to the .zip archive that is going to be extracted.
    folder_out : `str`
        Path to the folder that receives the extracted files.
    """
    with zipfile.ZipFile(f_in, mode="r") as archive:
        archive.extractall(path=folder_out)
1243
1244
[docs]
def get_temp_folder() -> str:
    """
    Returns the directory that is used for storing temporary files,
    as reported by :func:`tempfile.gettempdir`.

    Returns
    -------
    `str`
        Path to the folder for temporary files.
    """
    temp_dir = tempfile.gettempdir()
    return temp_dir
1255
1256
[docs]
def to_seconds(days: int = None, hours: int = None, minutes: int = None) -> int:
    """
    Computes the total number of seconds of a time span that is
    specified as a combination of days, hours, and minutes.

    Components that are not specified (i.e. None) do not contribute to the result.

    Parameters
    ----------
    days : `int`, optional
        Number of days -- must be positive if specified.
    hours : `int`, optional
        Number of hours -- must be positive if specified.
    minutes : `int`, optional
        Number of minutes -- must be positive if specified.

    Returns
    -------
    `int`
        Total time span in seconds.

    Raises
    ------
    TypeError
        If a specified component is not an instance of `int`.
    ValueError
        If a specified component is zero or negative.
    """
    # (argument name, given value, seconds per unit) -- checked in this order
    components = (("days", days, 24*60*60),
                  ("hours", hours, 60*60),
                  ("minutes", minutes, 60))

    total = 0
    for name, value, seconds_per_unit in components:
        if value is None:
            continue

        if not isinstance(value, int):
            raise TypeError(f"'{name}' must be an instance of 'int' but not of {type(value)}")
        if value <= 0:
            raise ValueError(f"'{name}' must be positive")

        total += seconds_per_unit * value

    return total