Source code for epyt_flow.utils

   1"""
   2Module provides helper functions.
   3"""
   4import os
   5import math
   6import tempfile
   7import zipfile
   8from pathlib import Path
   9import re
  10import requests
  11import time
  12import multiprocessing as mp
  13from deprecated import deprecated
  14from tqdm import tqdm
  15import numpy as np
  16import matplotlib
  17import matplotlib.pyplot as plt
  18from epanet_plus import EpanetConstants
  19
  20
# IDs of area units -- see `areaunit_to_id` and `areaunit_to_str`.
AREA_UNIT_FT2 = 1
AREA_UNIT_M2 = 2
AREA_UNIT_CM2 = 3
# IDs of mass units -- see `massunit_to_id` and `massunit_to_str`.
MASS_UNIT_MG = 4
MASS_UNIT_UG = 5
MASS_UNIT_MOL = 6
MASS_UNIT_MMOL = 7
# ID of the time unit "hrs" -- used by `qualityunit_to_id` and `qualityunit_to_str`.
TIME_UNIT_HRS = 8
# ID returned by `massunit_to_id` for any mass unit description it does not recognize.
MASS_UNIT_CUSTOM = 9
  30
  31
def pressureunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given pressure unit ID.

    Parameters
    ----------
    unit_id : `int`
        Pressure unit ID.

        Must be one of the following EPANET constants:

            - EN_PSI    = 0  (Pounds per square inch)
            - EN_KPA    = 1  (Kilopascals)
            - EN_METERS = 2  (Meters)
            - EN_BAR    = 3  (Bar)
            - EN_FEET   = 4  (Feet)

    Returns
    -------
    `str`
        Pressure unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the constants listed above.
    """
    if unit_id is None:
        return ""

    # Scanned in order with '==' (no hashing of 'unit_id' is required).
    known_units = (
        (EpanetConstants.EN_PSI, "psi"),
        (EpanetConstants.EN_KPA, "kilopascal"),
        (EpanetConstants.EN_METERS, "meter"),
        (EpanetConstants.EN_BAR, "bar"),
        (EpanetConstants.EN_FEET, "feet"),
    )
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
68 69
def flowunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given flow unit ID.

    Parameters
    ----------
    unit_id : `int`
        Flow unit ID.

        Must be one of the following EPANET constants:

            - EN_CFS  = 0  (cubic foot/sec)
            - EN_GPM  = 1  (gal/min)
            - EN_MGD  = 2  (Million gal/day)
            - EN_IMGD = 3  (Imperial MGD)
            - EN_AFD  = 4  (ac-foot/day)
            - EN_LPS  = 5  (liter/sec)
            - EN_LPM  = 6  (liter/min)
            - EN_MLD  = 7  (Megaliter/day)
            - EN_CMH  = 8  (cubic meter/hr)
            - EN_CMD  = 9  (cubic meter/day)
            - EN_CMS  = 10 (cubic meter/sec)

    Returns
    -------
    `str`
        Flow unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the constants listed above.
    """
    if unit_id is None:
        return ""

    # Scanned in order with '==' (no hashing of 'unit_id' is required).
    known_units = (
        (EpanetConstants.EN_CFS, "cubic foot/sec"),
        (EpanetConstants.EN_GPM, "gal/min"),
        (EpanetConstants.EN_MGD, "Million gal/day"),
        (EpanetConstants.EN_IMGD, "Imperial MGD"),
        (EpanetConstants.EN_AFD, "ac-foot/day"),
        (EpanetConstants.EN_LPS, "liter/sec"),
        (EpanetConstants.EN_LPM, "liter/min"),
        (EpanetConstants.EN_MLD, "Megaliter/day"),
        (EpanetConstants.EN_CMH, "cubic meter/hr"),
        (EpanetConstants.EN_CMD, "cubic meter/day"),
        (EpanetConstants.EN_CMS, "cubic meter/sec"),
    )
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
124 125
def areaunit_to_id(unit_desc: str) -> int:
    """
    Maps an area units string to its area unit ID.

    Parameters
    ----------
    unit_desc : `str`
        Area units string -- one of "FT2", "M2", or "CM2".

    Returns
    -------
    `int`
        Corresponding area unit ID
        (AREA_UNIT_FT2, AREA_UNIT_M2, or AREA_UNIT_CM2).

    Raises
    ------
    KeyError
        If 'unit_desc' is not one of the strings listed above.
    """
    area_unit_ids = {
        "FT2": AREA_UNIT_FT2,
        "M2": AREA_UNIT_M2,
        "CM2": AREA_UNIT_CM2,
    }

    return area_unit_ids[unit_desc]
143 144
def massunit_to_id(unit_desc: str) -> int:
    """
    Maps a mass units string to its mass unit ID.

    Parameters
    ----------
    unit_desc : `str`
        Mass units string.

        "MG", "UG", "MOL", and "MMOL" are recognized; every other
        string is treated as a custom unit.

    Returns
    -------
    `int`
        Corresponding mass unit ID -- MASS_UNIT_CUSTOM if 'unit_desc'
        is not one of the recognized strings.
    """
    known_mass_units = {
        "MG": MASS_UNIT_MG,
        "UG": MASS_UNIT_UG,
        "MOL": MASS_UNIT_MOL,
        "MMOL": MASS_UNIT_MMOL,
    }

    return known_mass_units.get(unit_desc, MASS_UNIT_CUSTOM)
168 169
[docs] 170def qualityunit_to_id(unit_desc: str) -> int: 171 """ 172 Converts a given measurement unit description to the corresponding mass unit ID. 173 174 Parameters 175 ---------- 176 unit_desc : `str` 177 Mass unit. 178 179 Returns 180 ------- 181 `int` 182 Mass unit ID. 183 184 Will be either None (if no water quality analysis was set up) or 185 one of the following constants: 186 187 - MASS_UNIT_MG = 4 (mg/L) 188 - MASS_UNIT_UG = 5 (ug/L) 189 - TIME_UNIT_HRS = 8 (hrs) 190 """ 191 if unit_desc == "mg/L": 192 return MASS_UNIT_MG 193 elif unit_desc == "ug/L": 194 return MASS_UNIT_UG 195 elif unit_desc == "hrs": 196 return TIME_UNIT_HRS 197 else: 198 return None
199 200
def massunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given mass unit ID.

    Parameters
    ----------
    unit_id : `int`
        ID of the mass unit.

        Must be one of the following constants:

            - MASS_UNIT_MG     = 4
            - MASS_UNIT_UG     = 5
            - MASS_UNIT_MOL    = 6
            - MASS_UNIT_MMOL   = 7
            - MASS_UNIT_CUSTOM = 9

    Returns
    -------
    `str`
        Mass unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the constants listed above.
    """
    if unit_id is None:
        return ""

    # Scanned in order with '==' (no hashing of 'unit_id' is required).
    known_units = (
        (MASS_UNIT_MG, "MG"),
        (MASS_UNIT_UG, "UG"),
        (MASS_UNIT_MOL, "MOL"),
        (MASS_UNIT_MMOL, "MMOL"),
        (MASS_UNIT_CUSTOM, "CUSTOM UNIT"),
    )
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown mass unit ID '{unit_id}'")
237 238
def qualityunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given quality measurement unit ID.

    Parameters
    ----------
    unit_id : `int`
        ID of the quality unit.

        Must be one of the following constants:

            - MASS_UNIT_MG  = 4  (mg/L)
            - MASS_UNIT_UG  = 5  (ug/L)
            - TIME_UNIT_HRS = 8  (hrs)

    Returns
    -------
    `str`
        Quality unit description -- an empty string if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the constants listed above.
    """
    if unit_id is None:
        return ""

    # Scanned in order with '==' (no hashing of 'unit_id' is required).
    known_units = (
        (MASS_UNIT_MG, "mg/L"),
        (MASS_UNIT_UG, "ug/L"),
        (TIME_UNIT_HRS, "hrs"),
    )
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
269 270
def areaunit_to_str(unit_id: int) -> str:
    """
    Returns the textual description of a given area measurement unit ID.

    Parameters
    ----------
    unit_id : `int`
        ID of the area unit.

        Must be one of the following constants:

            - AREA_UNIT_FT2 = 1
            - AREA_UNIT_M2  = 2
            - AREA_UNIT_CM2 = 3

    Returns
    -------
    `str`
        Area unit description -- None (not an empty string) if 'unit_id' is None.

    Raises
    ------
    ValueError
        If 'unit_id' is neither None nor one of the constants listed above.
    """
    if unit_id is None:
        return None

    # Scanned in order with '==' (no hashing of 'unit_id' is required).
    known_units = (
        (AREA_UNIT_FT2, "FT2"),
        (AREA_UNIT_M2, "M2"),
        (AREA_UNIT_CM2, "CM2"),
    )
    for known_id, description in known_units:
        if unit_id == known_id:
            return description

    raise ValueError(f"Unknown unit ID '{unit_id}'")
301 302
def is_flowunit_simetric(unit_id: int) -> bool:
    """
    Checks whether a given flow unit is a SI metric unit.

    Parameters
    ----------
    unit_id : `int`
        ID of the flow unit.

        Must be one of the following EPANET constants:

            - EN_CFS  = 0  (cubic foot/sec)
            - EN_GPM  = 1  (gal/min)
            - EN_MGD  = 2  (Million gal/day)
            - EN_IMGD = 3  (Imperial MGD)
            - EN_AFD  = 4  (ac-foot/day)
            - EN_LPS  = 5  (liter/sec)
            - EN_LPM  = 6  (liter/min)
            - EN_MLD  = 7  (Megaliter/day)
            - EN_CMH  = 8  (cubic meter/hr)
            - EN_CMD  = 9  (cubic meter/day)
            - EN_CMS  = 10 (cubic meter/sec)

    Returns
    -------
    `bool`
        True if the given unit is a SI metric unit
        (EN_LPS, EN_LPM, EN_MLD, EN_CMH, EN_CMD, or EN_CMS), False otherwise.
    """
    si_metric_units = (
        EpanetConstants.EN_LPS,
        EpanetConstants.EN_LPM,
        EpanetConstants.EN_MLD,
        EpanetConstants.EN_CMH,
        EpanetConstants.EN_CMD,
        EpanetConstants.EN_CMS,
    )

    return unit_id in si_metric_units
def _get_pressure_convert_factor(new_unit_id: int, old_unit: int) -> float:
    """
    Returns the factor by which a pressure value given in 'old_unit' must be
    multiplied to express it in 'new_unit_id'.

    Parameters
    ----------
    new_unit_id : `int`
        ID of the target pressure unit -- one of the EPANET constants
        EN_PSI, EN_KPA, EN_METERS, EN_BAR, or EN_FEET.
    old_unit : `int`
        ID of the current pressure unit (same set of constants).

    Returns
    -------
    `float`
        Conversion factor -- 1. if both units are equal.

    Raises
    ------
    ValueError
        If 'new_unit_id' or 'old_unit' is not a supported pressure unit.
    """
    # Must stay first: equal units never need a table lookup.
    if new_unit_id == old_unit:
        return 1.

    psi = EpanetConstants.EN_PSI
    kpa = EpanetConstants.EN_KPA
    meters = EpanetConstants.EN_METERS
    bar = EpanetConstants.EN_BAR
    feet = EpanetConstants.EN_FEET

    # factors[new][old] = multiplier that converts a value from 'old' to 'new'
    factors = {
        bar: {psi: .0689476,
              meters: .09804139432,
              feet: .029883016988736,
              kpa: .01},
        kpa: {psi: 6.894744825494,
              meters: 9.804139432,
              feet: 2.9890669,
              bar: 100.},
        feet: {psi: 2.3072493927233,
               meters: 3.2808398950131,
               kpa: .33455256555148,
               bar: 33.455256555148},
        psi: {kpa: .14503773773020923,
              meters: 1.4219702063247,
              feet: .43341651888775,
              bar: 14.5038},
        meters: {psi: .70325,
                 kpa: .10199773339984,
                 feet: .3048,
                 bar: 10.199773339984},
    }

    try:
        factors_to_new_unit = factors[new_unit_id]
    except (KeyError, TypeError):
        raise ValueError("Invalid 'new_unit_id'") from None

    # Fail loudly instead of silently falling through and returning None.
    try:
        return factors_to_new_unit[old_unit]
    except (KeyError, TypeError):
        raise ValueError("Invalid 'old_unit'") from None


def _get_flow_convert_factor(new_unit_id: int, old_unit: int) -> float:
    """
    Returns the factor by which a flow value given in 'old_unit' must be
    multiplied to express it in 'new_unit_id'.

    Parameters
    ----------
    new_unit_id : `int`
        ID of the target flow unit -- one of the EPANET constants
        EN_CFS, EN_GPM, EN_MGD, EN_IMGD, EN_AFD, EN_LPS, EN_LPM, EN_MLD,
        EN_CMH, EN_CMD, or EN_CMS.
    old_unit : `int`
        ID of the current flow unit (same set of constants).

    Returns
    -------
    `float`
        Conversion factor -- 1. if both units are equal.

    Raises
    ------
    ValueError
        If 'new_unit_id' or 'old_unit' is not a supported flow unit.
    """
    # Must stay first: equal units never need a table lookup.
    if new_unit_id == old_unit:
        return 1.

    cfs = EpanetConstants.EN_CFS
    gpm = EpanetConstants.EN_GPM
    mgd = EpanetConstants.EN_MGD
    imgd = EpanetConstants.EN_IMGD
    afd = EpanetConstants.EN_AFD
    lps = EpanetConstants.EN_LPS
    lpm = EpanetConstants.EN_LPM
    mld = EpanetConstants.EN_MLD
    cmh = EpanetConstants.EN_CMH
    cmd = EpanetConstants.EN_CMD
    cms = EpanetConstants.EN_CMS

    # The table below has no entries for cubic meter/sec.
    # Since 1 cubic meter/sec = 1000 liter/sec, go through liter/sec instead.
    if new_unit_id == cms:
        return _get_flow_convert_factor(lps, old_unit) / 1000.
    if old_unit == cms:
        return _get_flow_convert_factor(new_unit_id, lps) * 1000.

    # factors[new][old] = multiplier that converts a value from 'old' to 'new'
    factors = {
        cfs: {gpm: .0022280093,
              mgd: 1.5472286523,
              imgd: 1.8581441347,
              afd: .5041666667,
              lps: .0353146667,
              lpm: .0005885778,
              mld: .40873456853575,
              cmh: .0098096296,
              cmd: .0004087346},
        gpm: {cfs: 448.8325660485,
              mgd: 694.44444444,
              imgd: 833.99300382,
              afd: 226.28571429,
              lps: 15.850323141,
              lpm: .2641720524,
              mld: 183.4528141376,
              cmh: 4.4028675393,
              cmd: .1834528141},
        mgd: {cfs: .6463168831,
              gpm: .00144,
              imgd: 1.2009499255,
              afd: 0.3258514286,
              lps: .0228244653,
              lpm: .0003804078,
              mld: .26417205124156,
              cmh: .0063401293,
              cmd: .0002641721},
        imgd: {cfs: .5381713837,
               mgd: .8326741846,
               gpm: .0011990508,
               afd: .2713280726,
               lps: .0190053431,
               lpm: .0003167557,
               mld: .21996924829908776,
               cmh: .005279262,
               cmd: .0002199692},
        afd: {cfs: 1.9834710744,
              mgd: 3.0688832772,
              gpm: .0044191919,
              imgd: 3.6855751432,
              lps: .0700456199,
              lpm: .001167427,
              mld: .81070995093708,
              cmh: .0194571167,
              cmd: .0008107132},
        lps: {cfs: 28.316846592,
              mgd: 43.812636389,
              imgd: 52.616782407,
              gpm: .0630901964,
              afd: 14.276410157,
              lpm: .0166666667,
              mld: 11.574074074074,
              cmh: .2777777778,
              cmd: .0115740741},
        lpm: {cfs: 1699.0107955,
              mgd: 2628.7581833,
              imgd: 3157.0069444,
              afd: 856.58460941,
              lps: 60,
              gpm: 3.785411784,
              mld: 694.44444444443,
              cmh: 16.666666667,
              cmd: 0.6944444444},
        mld: {cfs: 2.4465755456688,
              mgd: 3.7854117999999777,
              imgd: 4.54609,
              afd: 1.2334867714947,
              lps: .0864,
              lpm: .00144,
              gpm: .00545099296896,
              cmh: .024,
              cmd: .00099999999999999},
        cmh: {cfs: 101.94064773,
              mgd: 157.725491,
              imgd: 189.42041667,
              afd: 51.395076564,
              lps: 3.6,
              lpm: .06,
              mld: 41.666666666666,
              gpm: .227124707,
              cmd: 0.0416666667},
        cmd: {cfs: 2446.5755455,
              mgd: 3785.411784,
              imgd: 4546.09,
              afd: 1233.4818375,
              lps: 86.4,
              lpm: 1.44,
              mld: 1000.,
              cmh: 24,
              gpm: 5.450992969},
    }

    try:
        factors_to_new_unit = factors[new_unit_id]
    except (KeyError, TypeError):
        raise ValueError("Invalid 'new_unit_id'") from None

    # Fail loudly instead of silently falling through and returning None.
    try:
        return factors_to_new_unit[old_unit]
    except (KeyError, TypeError):
        raise ValueError("Invalid 'old_unit'") from None
def time_points_to_one_hot_encoding(time_points: list[int], total_length: int) -> list[int]:
    """
    Builds a 0/1 encoding of a given list of time points.

    Parameters
    ----------
    time_points : `list[int]`
        Time points (i.e. indices) that are set to one in the encoding.
    total_length : `int`
        Length of the final encoding.

    Returns
    -------
    `list[int]`
        List of length 'total_length' containing a one at every index
        listed in 'time_points' and a zero everywhere else.
    """
    encoding = [0] * total_length

    for time_idx in time_points:
        encoding[time_idx] = 1

    return encoding
620 621
def volume_to_level(tank_volume: float, tank_diameter: float) -> float:
    """
    Computes the water level in a (cylindrical) tank containing a given volume of water.

    Parameters
    ----------
    tank_volume : `float`
        Water volume in the tank -- must not be negative.
    tank_diameter : `float`
        Diameter of the tank -- must be greater than zero.

    Returns
    -------
    `float`
        Water level in tank.

    Raises
    ------
    TypeError
        If 'tank_volume' or 'tank_diameter' is not a real number.
    ValueError
        If 'tank_volume' is negative or 'tank_diameter' is not greater than zero.
    """
    # Besides 'float', also accept integers and NumPy scalars (e.g. float32 entries of
    # an array) -- 'bool' is a subclass of 'int' and therefore excluded explicitly.
    real_types = (int, float, np.integer, np.floating)

    if isinstance(tank_volume, bool) or not isinstance(tank_volume, real_types):
        raise TypeError("'tank_volume' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_volume)}'")
    if tank_volume < 0:
        raise ValueError("'tank_volume' can not be negative")
    if isinstance(tank_diameter, bool) or not isinstance(tank_diameter, real_types):
        raise TypeError("'tank_diameter' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_diameter)}'")
    if tank_diameter <= 0:
        raise ValueError("'tank_diameter' must be greater than zero")

    # level = volume / base area, where base area = pi * (diameter / 2)^2
    return (4. / (math.pow(tank_diameter, 2) * math.pi)) * tank_volume
650 651
def level_to_volume(tank_level: float, tank_diameter: float) -> float:
    """
    Computes the volume of water in a (cylindrical) tank given the water level in the tank.

    Parameters
    ----------
    tank_level : `float`
        Water level in the tank -- must not be negative.
    tank_diameter : `float`
        Diameter of the tank -- must be greater than zero.

    Returns
    -------
    `float`
        Water volume in tank.

    Raises
    ------
    TypeError
        If 'tank_level' or 'tank_diameter' is not a real number.
    ValueError
        If 'tank_level' is negative or 'tank_diameter' is not greater than zero.
    """
    # Besides 'float', also accept integers and NumPy scalars (e.g. float32 entries of
    # an array) -- 'bool' is a subclass of 'int' and therefore excluded explicitly.
    real_types = (int, float, np.integer, np.floating)

    if isinstance(tank_level, bool) or not isinstance(tank_level, real_types):
        raise TypeError("'tank_level' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_level)}'")
    if tank_level < 0:
        raise ValueError("'tank_level' can not be negative")
    if isinstance(tank_diameter, bool) or not isinstance(tank_diameter, real_types):
        raise TypeError("'tank_diameter' must be an instance of 'float' or 'int' " +
                        f"but not of '{type(tank_diameter)}'")
    if tank_diameter <= 0:
        raise ValueError("'tank_diameter' must be greater than zero")

    # volume = level * base area, where base area = pi * (diameter / 2)^2
    return tank_level * math.pow(0.5 * tank_diameter, 2) * math.pi
680 681
def plot_timeseries_data(data: np.ndarray, labels: list[str] = None, x_axis_label: str = None,
                         y_axis_label: str = None, y_ticks: tuple[list[float], list[str]] = None,
                         show: bool = True, save_to_file: str = None,
                         ax: matplotlib.axes.Axes = None) -> matplotlib.axes.Axes:
    """
    Plots a single or multiple time series.

    Parameters
    ----------
    data : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Time series data -- each row in `data` corresponds to a complete time series.
    labels : `list[str]`, optional
        Labels for each time series in `data` -- at least one label per row is required.
        If None, no labels are shown.

        The default is None.
    x_axis_label : `str`, optional
        X axis label.

        The default is None.
    y_axis_label : `str`, optional
        Y axis label.

        The default is None.
    y_ticks: `(list[float], list[str])`, optional
        Tuple of ticks (numbers) and labels (strings) for the y-axis.

        The default is None.
    show : `bool`, optional
        If True, the plot/figure is shown in a window.

        Only considered when 'ax' is None.

        The default is True.
    save_to_file : `str`, optional
        File to which the plot is saved.

        If specified, 'show' must be set to False --
        i.e. a plot can not be shown and saved to a file at the same time!

        The default is None.
    ax : `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_, optional
        If not None, 'ax' is used for plotting.

        The default is None.

    Returns
    -------
    `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_
        Plot.

    Raises
    ------
    TypeError
        If an argument is of the wrong type.
    ValueError
        If 'data' is not a 2d array, fewer labels than time series are given,
        'y_ticks' does not contain exactly two items, or
        'save_to_file' is set while 'show' is True.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError(f"'data' must be an instance of 'numpy.ndarray' but not of '{type(data)}'")
    if len(data.shape) != 2:
        raise ValueError("'data' must be a 2d array where each row corresponds to a time series " +
                         "-- use '.reshape(1, -1)' in case of single time series")
    if labels is not None:
        if not isinstance(labels, list) or not all(isinstance(label, str) for label in labels):
            raise TypeError("'labels' must be a instance of 'list[str]'")
        # Checked up front -- otherwise plotting fails half-way with an IndexError
        if len(labels) < data.shape[0]:
            raise ValueError("'labels' must contain a label for each time series in 'data' -- " +
                             f"expected {data.shape[0]} labels but got {len(labels)}")
    if x_axis_label is not None:
        if not isinstance(x_axis_label, str):
            raise TypeError("'x_axis_label' must be an instance of 'str' " +
                            f"but not of '{type(x_axis_label)}'")
    if y_axis_label is not None:
        if not isinstance(y_axis_label, str):
            raise TypeError("'y_axis_label' must be an instance of 'str' " +
                            f"but not of '{type(y_axis_label)}'")
    if y_ticks is not None:
        if len(y_ticks) != 2:
            raise ValueError("'y_ticks' must be a tuple ticks (numbers) and labels (strings)")
    if not isinstance(show, bool):
        raise TypeError(f"'show' must be an instance of 'bool' but not of '{type(show)}'")
    if save_to_file is not None:
        if show is True:
            raise ValueError("'show' must be False if 'save_to_file' is set")

        if not isinstance(save_to_file, str):
            raise TypeError("'save_to_file' must be an instance of 'str' but not of " +
                            f"'{type(save_to_file)}'")
    if ax is not None:
        if not isinstance(ax, matplotlib.axes.Axes):
            raise TypeError("'ax' must be an instance of 'matplotlib.axes.Axes' " +
                            f"but not of '{type(ax)}'")

    # 'fig' stays None if the caller provided the axes
    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    labels = labels if labels is not None else [None] * data.shape[0]

    for i in range(data.shape[0]):
        ax.plot(data[i, :], ".-", label=labels[i])

    # A legend is only shown if labels were provided
    if not any(label is None for label in labels):
        ax.legend()

    if x_axis_label is not None:
        ax.set_xlabel(x_axis_label)
    if y_axis_label is not None:
        ax.set_ylabel(y_axis_label)
    if y_ticks is not None:
        yticks_pos, yticks_labels = y_ticks
        ax.set_yticks(yticks_pos, labels=yticks_labels)

    if show is True and fig is not None:
        plt.show()
    if save_to_file is not None:
        folder_path = str(Path(save_to_file).parent.absolute())
        create_path_if_not_exist(folder_path)

        # Save the figure that owns 'ax' -- pyplot's *current* figure might be a
        # different one if the axes were provided by the caller.
        # Falls back to 'plt.savefig' if that object can not be saved directly.
        target_fig = fig if fig is not None else ax.get_figure()
        save_figure = getattr(target_fig, "savefig", plt.savefig)
        save_figure(save_to_file, bbox_inches='tight')

    return ax
798 799
def plot_timeseries_prediction(y: np.ndarray, y_pred: np.ndarray,
                               confidence_interval: np.ndarray = None,
                               x_axis_label: str = None, y_axis_label: str = None,
                               y_ticks: tuple[list[float], list[str]] = None,
                               show: bool = True, save_to_file: str = None,
                               ax: matplotlib.axes.Axes = None
                               ) -> matplotlib.axes.Axes:
    """
    Plots the prediction (e.g. forecast) of *single* time series together with the
    ground truth time series. In addition, confidence intervals can be plotted as well.

    Parameters
    ----------
    y : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Ground truth values (1d array).
    y_pred : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Predicted values (1d array of the same shape as `y`).
    confidence_interval : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Confidence interval for the predictions in `y_pred`.
        If not None, the area between `y_pred - confidence_interval[0]` and
        `y_pred + confidence_interval[1]` is shaded.

        The default is None.
    x_axis_label : `str`, optional
        X axis label.

        The default is None.
    y_axis_label : `str`, optional
        Y axis label.

        The default is None.
    y_ticks: `(list[float], list[str])`, optional
        Tuple of ticks (numbers) and labels (strings) for the y-axis.

        The default is None.
    show : `bool`, optional
        If True, the plot/figure is shown in a window.

        Only considered when 'ax' is None.

        The default is True.
    save_to_file : `str`, optional
        File to which the plot is saved.

        If specified, 'show' must be set to False --
        i.e. a plot can not be shown and saved to a file at the same time!

        The default is None.
    ax : `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_, optional
        If not None, 'ax' is used for plotting.

        The default is None.

    Returns
    -------
    `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_
        Plot.

    Raises
    ------
    TypeError
        If an argument is of the wrong type.
    ValueError
        If `y` and `y_pred` differ in shape or are not 1d arrays,
        'y_ticks' does not contain exactly two items, or
        'save_to_file' is set while 'show' is True.
    """
    if not isinstance(y_pred, np.ndarray):
        raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
                        f"but not of '{type(y_pred)}'")
    if not isinstance(y, np.ndarray):
        raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
                        f"but not of '{type(y)}'")
    if y_pred.shape != y.shape:
        raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
    if len(y_pred.shape) != 1:
        raise ValueError("'y_pred' must be a 1d array")
    if len(y.shape) != 1:
        raise ValueError("'y' must be a 1d array")
    if x_axis_label is not None:
        if not isinstance(x_axis_label, str):
            raise TypeError("'x_axis_label' must be an instance of 'str' " +
                            f"but not of '{type(x_axis_label)}'")
    if y_axis_label is not None:
        if not isinstance(y_axis_label, str):
            raise TypeError("'y_axis_label' must be an instance of 'str' " +
                            f"but not of '{type(y_axis_label)}'")
    if y_ticks is not None:
        if len(y_ticks) != 2:
            raise ValueError("'y_ticks' must be a tuple ticks (numbers) and labels (strings)")
    if not isinstance(show, bool):
        raise TypeError(f"'show' must be an instance of 'bool' but not of '{type(show)}'")
    if save_to_file is not None:
        if show is True:
            raise ValueError("'show' must be False if 'save_to_file' is set")

        if not isinstance(save_to_file, str):
            raise TypeError("'save_to_file' must be an instance of 'str' but not of " +
                            f"'{type(save_to_file)}'")
    if ax is not None:
        if not isinstance(ax, matplotlib.axes.Axes):
            raise TypeError("'ax' must be an instance of 'matplotlib.axes.Axes' " +
                            f"but not of '{type(ax)}'")

    # 'fig' stays None if the caller provided the axes
    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    if confidence_interval is not None:
        ax.fill_between(range(len(y_pred)),
                        y_pred - confidence_interval[0],
                        y_pred + confidence_interval[1],
                        alpha=0.5)
    ax.plot(y, ".-", label="Ground truth")
    ax.plot(y_pred, ".-", label="Prediction")
    ax.legend()

    if x_axis_label is not None:
        ax.set_xlabel(x_axis_label)
    if y_axis_label is not None:
        ax.set_ylabel(y_axis_label)
    if y_ticks is not None:
        yticks_pos, yticks_labels = y_ticks
        ax.set_yticks(yticks_pos, labels=yticks_labels)

    if show is True and fig is not None:
        plt.show()
    if save_to_file is not None:
        folder_path = str(Path(save_to_file).parent.absolute())
        create_path_if_not_exist(folder_path)

        # Save the figure that owns 'ax' -- pyplot's *current* figure might be a
        # different one if the axes were provided by the caller.
        # Falls back to 'plt.savefig' if that object can not be saved directly.
        target_fig = fig if fig is not None else ax.get_figure()
        save_figure = getattr(target_fig, "savefig", plt.savefig)
        save_figure(save_to_file, bbox_inches='tight')

    return ax
927 928
[docs] 929def robust_download(download_path: str, urls: str | list, 930 verbose: bool = True, timeout: int = 30) -> None: 931 """ 932 Downloads a file from the given urls if it does not already exist in the 933 given path. The urls are tried in order. If a download stops or stalls, 934 the next url is tried until one succeeds or all urls have failed. 935 936 Parameters 937 ---------- 938 download_path : `str` 939 Local path to the file -- if this path does not exist, the file will be 940 downloaded from the provided 'urls' and stored there. 941 urls : `list` or `str` 942 One url or a list of urls (where additional urls function as backup) to 943 download the file from. 944 verbose : `bool`, optional 945 If True, a progress bar is shown while downloading the file. 946 947 The default is True. 948 timeout : `int` 949 If this time passed without progress while downloading, the download is 950 considered failed. 951 952 The default is 30 seconds. 953 """ 954 if isinstance(urls, str): 955 urls = [urls] 956 957 for url in urls: 958 try: 959 download_if_necessary(download_path, url, verbose, timeout) 960 return 961 except Exception as e: 962 print(f"Failed url: {url} with {e}") 963 continue 964 965 raise SystemError("All download attempts failed")
def _download_process(download_path: str, url: str, backup_urls: list[str],
                      last_update: mp.Value, stop_flag: mp.Value,
                      finish_flag: mp.Value, verbose: bool) -> None:
    """
    Process that handles the actual download. It updates the last download
    update variable and cleans up the corrupted file if the download fails from
    within.

    This function is only to be called from `download_if_necessary`.

    Parameters
    ----------
    download_path : `str`
        Local path where the downloaded file is stored.
    url : `str`
        Web-URL pointing to the source the file should be downloaded from. Can
        also point to a google drive file.
    backup_urls : `list[str]`
        List of alternative URLs that are being tried in the case that
        downloading from 'url' fails. This is deprecated, but left in for
        downward compatibility with `download_if_necessary` calls with
        backup_urls. Not necessary when using `robust_download`.
    last_update : `mp.Value`
        Shared variable to keep track of the last successful download update.
    stop_flag : `mp.Value`
        Shared variable. Set to 1 when this process stopped by finishing or
        error.
    finish_flag : `mp.Value`
        Shared variable. Set to 1 when download finished successfully.
    verbose : `bool`
        If True, a progress bar is shown while downloading the file.

    Raises
    ------
    SystemError
        If the google drive download parameters can not be extracted, or
        if no URL answers with status code 200.
    """
    progress_bar = None
    response = None
    session = None

    try:
        if "drive.google.com" in url:
            session = requests.Session()
            response = session.get(url)
            html = response.text

            def extract(pattern):
                match = re.search(pattern, html)
                return match.group(1) if match else None

            file_id = extract(r'name="id" value="([^"]+)"')
            file_confirm = extract(r'name="confirm" value="([^"]+)"')
            file_uuid = extract(r'name="uuid" value="([^"]+)"')

            if not all([file_id, file_confirm, file_uuid]):
                raise SystemError("Failed to extract download parameters")

            download_url = (
                f"https://drive.usercontent.google.com/download"
                f"?id={file_id}&export=download&confirm={file_confirm}&uuid={file_uuid}"
            )

            # The first response only delivered the HTML form and is no longer needed
            response.close()
            response = session.get(download_url, stream=True)

        else:
            response = requests.get(url, stream=True, allow_redirects=True,
                                    timeout=1000)

        # Deprecated, left in for backward compatibility
        if response.status_code != 200:
            for backup_url in backup_urls:
                # Release the failed response before trying the next URL
                response.close()
                # Always stream -- the content is written chunk by chunk below
                response = requests.get(backup_url, stream=True,
                                        allow_redirects=True, timeout=1000)
                if response.status_code == 200:
                    break
        if response.status_code != 200:
            raise SystemError(f"Failed to download -- {response.status_code}")

        content_length = int(response.headers.get("content-length", 0))
        with open(download_path, "wb") as file:
            if verbose:
                progress_bar = tqdm(desc=download_path, total=content_length,
                                    ascii=True, unit='B', unit_scale=True,
                                    unit_divisor=1024)
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                # Compare with None -- a tqdm bar with 'total=0' evaluates to False
                if progress_bar is not None:
                    progress_bar.update(size)
                with last_update.get_lock():
                    last_update.value = time.time()
        with finish_flag.get_lock():
            finish_flag.value = 1
        with stop_flag.get_lock():
            stop_flag.value = 1
    finally:
        # Compare with None -- tqdm bars ('total=0') and requests responses
        # (status code >= 400) can evaluate to False although they must be closed.
        if progress_bar is not None:
            progress_bar.close()
        if response is not None:
            response.close()
        if session is not None:
            session.close()
        # Remove the (incomplete) file if the download did not finish successfully
        with finish_flag.get_lock():
            if os.path.exists(download_path) and finish_flag.value == 0:
                os.remove(download_path)
        with stop_flag.get_lock():
            stop_flag.value = 1
@deprecated(reason="Please use new function `robust_download` instead.")
def download_from_gdrive_if_necessary(download_path: str, url: str, verbose: bool = True) -> None:
    """
    Fetches a file from a google drive repository, unless a file already exists
    at the given local path.

    The parent folder of the local path is created if it is missing.

    Parameters
    ----------
    download_path : `str`
        Local path to the file -- if no file exists at this path, it is
        downloaded from the provided 'url' and stored at this path.
    url : `str`
        Web-URL of the google drive repository.
    verbose : `bool`, optional
        If True, a progress bar is shown while downloading the file.

        The default is True.

    Raises
    ------
    SystemError
        If the download parameters can not be extracted from the google drive
        page or if the download request does not answer with HTTP status 200.
    """
    create_path_if_not_exist(str(Path(download_path).parent.absolute()))

    # Nothing to do if the file is already available locally.
    if os.path.isfile(download_path):
        return

    gdrive_session = requests.Session()
    landing_page = gdrive_session.get(url).text

    # Hidden form fields of the landing page that carry the parameters
    # needed for building the final download URL.
    form_patterns = (
        r'name="id" value="([^"]+)"',
        r'name="confirm" value="([^"]+)"',
        r'name="uuid" value="([^"]+)"',
    )
    form_values = []
    for form_pattern in form_patterns:
        found = re.search(form_pattern, landing_page)
        form_values.append(found.group(1) if found else None)
    file_id, file_confirm, file_uuid = form_values

    if not (file_id and file_confirm and file_uuid):
        raise SystemError("Failed to extract download parameters")

    download_url = (
        f"https://drive.usercontent.google.com/download"
        f"?id={file_id}&export=download&confirm={file_confirm}&uuid={file_uuid}"
    )

    reply = gdrive_session.get(download_url, stream=True)
    if reply.status_code != 200:
        raise SystemError(f"Failed to download -- {reply.status_code}")

    if verbose is not True:
        # No progress reporting -- write the complete content at once.
        with open(download_path, "wb") as f_out:
            f_out.write(reply.content)
        return

    content_length = int(reply.headers.get('content-length', 0))
    with open(download_path, "wb") as f_out, tqdm(desc=download_path,
                                                  total=content_length,
                                                  ascii=True,
                                                  unit='B',
                                                  unit_scale=True,
                                                  unit_divisor=1024) as progress_bar:
        for chunk in reply.iter_content(chunk_size=1024):
            progress_bar.update(f_out.write(chunk))
1135 1136
def download_if_necessary(download_path: str, url: str, verbose: bool = True,
                          backup_urls: list[str] = None, timeout: int = 30) -> None:
    """
    Downloads a file from a given URL if it does not already exist in a given
    path. This function is deprecated, please use `robust_download` instead.

    Note that if the path (folder) does not already exist, it will be created.

    The download itself runs in a separate process, which is monitored and
    aborted if it does not make any progress for 'timeout' seconds.

    Parameters
    ----------
    download_path : `str`
        Local path to the file -- if this path does not exist, the file will be
        downloaded from the provided 'url' and stored there.
    url : `str`
        Web-URL.
    verbose : `bool`, optional
        If True, a progress bar is shown while downloading the file.

        The default is True.
    backup_urls : `list[str]`, optional
        List of alternative URLs that are being tried in the case that downloading from 'url' fails.

        The default is None, which is treated as an empty list.
    timeout : `int`, optional
        Allowed download inactivity in seconds. After this time passed without
        an update, the download is considered failed.

        The default is 30 seconds.

    Raises
    ------
    SystemError
        If the download failed or if it was aborted because of inactivity.
        A partially downloaded file is removed in both cases.
    """
    folder_path = str(Path(download_path).parent.absolute())
    create_path_if_not_exist(folder_path)

    if os.path.isfile(download_path):
        return

    # Avoid a shared mutable default argument.
    backup_urls = [] if backup_urls is None else list(backup_urls)

    last_update = mp.Value('d', time.time())
    stop_flag = mp.Value('i', 0)
    finish_flag = mp.Value('i', 0)

    def remove_partial_file() -> None:
        # Only to be called while holding the lock of 'finish_flag'.
        if os.path.exists(download_path) and finish_flag.value == 0:
            os.remove(download_path)

    worker = mp.Process(target=_download_process,
                        args=(download_path, url, backup_urls, last_update,
                              stop_flag, finish_flag, verbose))
    worker.start()

    while True:
        time.sleep(1)

        with last_update.get_lock():
            idle = time.time() - last_update.value
        with stop_flag.get_lock():
            stopped = stop_flag.value == 1

        if stopped:
            # The download process sets 'finish_flag' before 'stop_flag',
            # so 'finish_flag' is final here. Always reap the process --
            # also if the download failed.
            worker.join()
            with finish_flag.get_lock():
                if finish_flag.value == 1:
                    return
                remove_partial_file()
            raise SystemError(f"failed downloading from {url}")

        if idle > timeout:
            # Holding the lock makes sure that the download process is not
            # terminated while it is updating 'finish_flag'.
            with finish_flag.get_lock():
                worker.terminate()
                worker.join()
                if finish_flag.value == 1:
                    # Download completed right before the timeout was hit --
                    # the file is complete, so this is not a failure.
                    return
                remove_partial_file()
            raise SystemError(f"no progress in {timeout} seconds, aborting download")
1200 1201
def create_path_if_not_exist(path_in: str) -> None:
    """
    Makes sure that a given directory exists -- missing parent directories
    are created as well. An already existing directory is left untouched.

    Parameters
    ----------
    path_in : `str`
        Path of the directory to be created.
    """
    target_dir = Path(path_in)
    target_dir.mkdir(parents=True, exist_ok=True)
1212 1213
def pack_zip_archive(f_in: list[str], f_out: str) -> None:
    """
    Creates a .zip archive containing a given list of files.

    All files are stored using DEFLATE compression.

    Parameters
    ----------
    f_in : `list[str]`
        Paths of the files to be put into the .zip archive.
    f_out : `str`
        Path of the .zip archive to be created.
    """
    # The compression method is set once for the whole archive.
    with zipfile.ZipFile(f_out, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for file_path in f_in:
            archive.write(file_path)
1228 1229
def unpack_zip_archive(f_in: str, folder_out: str) -> None:
    """
    Extracts all files of a .zip archive into a given folder.

    Parameters
    ----------
    f_in : `str`
        Path of the .zip archive to be unpacked.
    folder_out : `str`
        Path of the folder where the extracted files are stored.
    """
    with zipfile.ZipFile(f_in, mode="r") as archive:
        archive.extractall(path=folder_out)
1243 1244
def get_temp_folder() -> str:
    """
    Returns the path to the folder that is used for storing temporary files.

    Returns
    -------
    `str`
        Path to the temporary folder.
    """
    temp_folder = tempfile.gettempdir()
    return temp_folder
1255 1256
def to_seconds(days: int = None, hours: int = None, minutes: int = None) -> int:
    """
    Computes the total number of seconds of a time span that is given as
    days, hours, and minutes.

    Components that are not specified (i.e. None) do not contribute.

    Parameters
    ----------
    days : `int`, optional
        Number of days -- must be positive if given.
    hours : `int`, optional
        Number of hours -- must be positive if given.
    minutes : `int`, optional
        Number of minutes -- must be positive if given.

    Returns
    -------
    `int`
        Time span in seconds.

    Raises
    ------
    TypeError
        If a given component is not an instance of `int`.
    ValueError
        If a given component is not positive.
    """
    # (parameter name, given value, seconds per unit) -- validated and
    # accumulated in exactly this order.
    components = (("days", days, 24*60*60),
                  ("hours", hours, 60*60),
                  ("minutes", minutes, 60))

    total = 0
    for name, amount, seconds_per_unit in components:
        if amount is None:
            continue
        if not isinstance(amount, int):
            raise TypeError(f"'{name}' must be an instance of 'int' but not of {type(amount)}")
        if amount <= 0:
            raise ValueError(f"'{name}' must be positive")

        total += seconds_per_unit * amount

    return total