1"""
2Module provides helper functions and data management classes for visualizing
3scenarios.
4"""
5import inspect
6from typing import Optional, Union, List, Tuple
7from collections.abc import Iterable
8
9import networkx as nx
10import networkx.drawing.nx_pylab as nxp
11import matplotlib as mpl
12import matplotlib.pyplot as plt
13import numpy as np
14from scipy.interpolate import CubicSpline
15
16from ..serialization import COLOR_SCHEMES_ID, JsonSerializable, serializable
17from ..simulation.scada.scada_data import ScadaData
18
19# Selection of functions for processing visualization data
20stat_funcs = {
21 'mean': np.mean,
22 'min': np.min,
23 'max': np.max
24}
25
26
27def my_draw_networkx_nodes(
28 G,
29 pos,
30 nodelist=None,
31 node_size=300,
32 node_color="#1f78b4",
33 node_shape="o",
34 alpha=None,
35 cmap=None,
36 vmin=None,
37 vmax=None,
38 ax=None,
39 linewidths=None,
40 edgecolors=None,
41 label=None,
42 margins=None,
43 hide_ticks=True,
44):
45 if ax is None:
46 ax = plt.gca()
47
48 if nodelist is None:
49 nodelist = list(G)
50
51 if len(nodelist) == 0: # empty nodelist, no drawing
52 return mpl.collections.PathCollection(None)
53
54 try:
55 xy = np.asarray([pos[v] for v in nodelist])
56 except KeyError as err:
57 raise nx.NetworkXError(f"Node {err} has no position.") from err
58
59 if isinstance(alpha, Iterable):
60 node_color = nxp.apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
61 alpha = None
62
63 if not isinstance(node_shape, np.ndarray) and not isinstance(node_shape, list):
64 node_shape = np.array([node_shape for _ in range(len(nodelist))])
65 elif isinstance(node_shape, list):
66 node_shape = np.asarray(node_shape)
67
68 for shape in list(set(node_shape)):
69 node_collection = ax.scatter(
70 xy[node_shape == shape, 0],
71 xy[node_shape == shape, 1],
72 s=node_size,
73 c=node_color,
74 marker=shape,
75 cmap=cmap,
76 vmin=vmin,
77 vmax=vmax,
78 alpha=alpha,
79 linewidths=linewidths,
80 edgecolors=edgecolors,
81 label=label,
82 )
83 if hide_ticks:
84 ax.tick_params(
85 axis="both",
86 which="both",
87 bottom=False,
88 left=False,
89 labelbottom=False,
90 labelleft=False,
91 )
92
93 if margins is not None:
94 if isinstance(margins, Iterable):
95 ax.margins(*margins)
96 else:
97 ax.margins(margins)
98
99 node_collection.set_zorder(2)
100 return node_collection
101
102
[docs]
103class JunctionObject:
104 """
105 Represents a junction component (e.g. nodes, tanks, reservoirs, ...) in a
106 water distribution network and manages all relevant attributes for drawing.
107
108 Attributes
109 ----------
110 nodelist : `list`
111 List of all nodes in WDN pertaining to this component type.
112 pos : `dict`
113 A dictionary mapping nodes to their coordinates in the correct format
114 for drawing.
115 node_shape : :class:`matplotlib.path.Path` or None
116 A shape representing the object, if none, the networkx default circle
117 is used.
118 node_size : `int`, default = 10
119 The size of each node.
120 node_color : `str` or `list`, default = 'k'
121 If `string`: the color for all nodes, if `list`: a list of lists
122 containing a numerical value for each node per frame, which will be
123 used for coloring.
124 interpolated : `bool`, default = False
125 Set to True, if node_colors are interpolated for smoother animation.
126 """
127 def __init__(self, nodelist: list, pos: dict, node_shape: mpl.path.Path = None,
128 node_size: int = 10, node_color: Union[str, list] = 'k',
129 interpolated: bool = False):
130 self.nodelist = nodelist
131 self.pos = pos
132 self.node_shape = node_shape
133 self.node_size = node_size
134 self.node_color = node_color
135 self.interpolated = interpolated
136
[docs]
137 def add_frame(self, statistic: str, values: np.ndarray,
138 pit: int, intervals: Union[int, List[Union[int, float]]]):
139 """
140 Adds a new frame of node_color based on a given statistic.
141
142 Parameters
143 ----------
144 statistic : `str`
145 The statistic to calculate for the data. Can be 'mean', 'min',
146 'max' or 'time_step'.
147 values : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
148 The node values over time as extracted from the scada data.
149 pit : `int`
150 The point in time for the 'time_step' statistic.
151 intervals : `int`, `list[int]` or `list[float]`
152 If provided, the data will be grouped into intervals. It can be an
153 integer specifying the number of groups or a list of boundary
154 points.
155
156 Raises
157 ------
158 ValueError
159 If interval, pit or statistic is not correctly provided.
160
161 """
162 if statistic in stat_funcs:
163 stat_values = stat_funcs[statistic](values, axis=0)
164 elif statistic == 'time_step':
165 if not pit and pit != 0:
166 raise ValueError(
167 'Please input point in time (pit) parameter when selecting'
168 ' time_step statistic')
169 stat_values = np.take(values, pit, axis=0)
170 else:
171 raise ValueError(
172 'Statistic parameter must be mean, min, max or time_step')
173
174 if intervals is None:
175 pass
176 elif isinstance(intervals, (int, float)):
177 interv = np.linspace(stat_values.min(), stat_values.max(),
178 intervals + 1)
179 stat_values = np.digitize(stat_values, interv) - 1
180 elif isinstance(intervals, list):
181 stat_values = np.digitize(stat_values, intervals) - 1
182 else:
183 raise ValueError(
184 'Intervals must be either number of groups or list of interval'
185 ' boundary points')
186
187 sorted_values = [v for _, v in zip(self.nodelist, stat_values)]
188
189 if isinstance(self.node_color, str):
190 # First run of this method
191 self.node_color = []
192 self.vmin = min(sorted_values)
193 self.vmax = max(sorted_values)
194
195 self.node_color.append(sorted_values)
196 self.vmin = min(*sorted_values, self.vmin)
197 self.vmax = max(*sorted_values, self.vmax)
198
[docs]
199 def get_frame(self, frame_number: int = 0):
200 """
201 Returns all attributes necessary for networkx to draw the specified
202 frame.
203
204 Parameters
205 ----------
206 frame_number : `int`, default = 0
207 The frame whose parameters should be returned. Default is 0, this
208 is also used if only 1 frame exists (e.g. for plots, not
209 animations).
210
211 Returns
212 -------
213 valid_params : `dict`
214 A dictionary containing all attributes that function as parameters
215 for `networkx.drawing.nx_pylab.draw_networkx_nodes() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_nodes.html#draw-networkx-nodes>`_.
216 """
217
218 attributes = vars(self).copy()
219
220 if not isinstance(self.node_color, str):
221 if self.interpolated:
222 if frame_number > len(self.node_color_inter):
223 frame_number = -1
224 attributes['node_color'] = self.node_color_inter[frame_number]
225 else:
226 if frame_number > len(self.node_color):
227 frame_number = -1
228 attributes['node_color'] = self.node_color[frame_number]
229
230 sig = inspect.signature(nxp.draw_networkx_nodes)
231
232 valid_params = {
233 key: value for key, value in attributes.items()
234 if key in sig.parameters and value is not None
235 }
236
237 return valid_params
238
[docs]
239 def get_frame_mask(self, mask, color):
240 """
241 Returns all attributes necessary for networkx to draw the specified
242 frame mask. Meaning covering all masked junction objects with the
243 default value.
244
245 Parameters
246 ----------
247 mask: `np.ndarray`
248 An array consisting of 0 and 1, where 0 means no sensor. Nodes
249 without sensor are to be masked.
250 color:
251 The default color of masked nodes.
252
253 Returns
254 -------
255 valid_params : `dict`
256 A dictionary containing all attributes that function as parameters
257 for `networkx.drawing.nx_pylab.draw_networkx_nodes() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_nodes.html#draw-networkx-nodes>`_.
258 """
259
260 attributes = vars(self).copy()
261
262 attributes['nodelist'] = [node for node, flag in
263 zip(self.nodelist, mask) if not flag]
264 attributes['node_color'] = color
265
266 sig = inspect.signature(nxp.draw_networkx_nodes)
267
268 valid_params = {
269 key: value for key, value in attributes.items()
270 if key in sig.parameters and key not in ['vmin', 'vmax', 'cmap']
271 and value is not None
272 }
273
274 return valid_params
275
[docs]
276 def interpolate(self, num_inter_frames: int):
277 """
278 Interpolates node_color values for smoother animations.
279
280 Parameters
281 ----------
282 num_inter_frames : `int`
283 The number of total frames after interpolation.
284 """
285 if isinstance(self.node_color, str) or len(self.node_color) <= 1:
286 return
287
288 tmp_node_color = np.array(self.node_color)
289 steps, num_nodes = tmp_node_color.shape
290
291 x_axis = np.linspace(0, steps - 1, steps)
292 new_x_axis = np.linspace(0, steps - 1, num_inter_frames)
293
294 self.node_color_inter = np.zeros(((len(new_x_axis)), num_nodes))
295
296 for node in range(num_nodes):
297 cs = CubicSpline(x_axis, tmp_node_color[:, node])
298 self.node_color_inter[:, node] = cs(new_x_axis)
299
300 self.interpolated = True
301
[docs]
302 def add_attributes(self, attributes: dict):
303 """
304 Adds the given attributes dict as class attributes.
305
306 Parameters
307 ----------
308 attributes : `dict`
309 Attributes dict, which is to be added as class attributes.
310 """
311 for key, value in attributes.items():
312 setattr(self, key, value)
313
314
[docs]
315class EdgeObject:
316 """
317 Represents an edge component (pipes) in a water distribution network and
318 manages all relevant attributes for drawing.
319
320 Attributes
321 ----------
322 edgelist : `list`
323 List of all edges in WDN pertaining to this component type.
324 pos : `dict`
325 A dictionary mapping pipes to their coordinates in the correct format
326 for drawing.
327 edge_color : `str` or `list`, default = 'k'
328 If `string`: the color for all edges, if `list`: a list of lists
329 containing a numerical value for each edge per frame, which will be
330 used for coloring.
331 interpolated : `dict`, default = {}
332 Filled with interpolated frames if interpolation method is called.
333 """
334 def __init__(self, edgelist: list, pos: dict, edge_color: Union[str, list] = 'k',
335 interpolated: dict = {}):
336 self.edgelist = edgelist
337 self.pos = pos
338 self.edge_color = edge_color
339 self.interpolated = interpolated
340
[docs]
341 def rescale_widths(self, line_widths: Tuple[int, int] = (1, 2)):
342 """
343 Rescales all edge widths to the given interval.
344
345 Parameters
346 ----------
347 line_widths : `Tuple[int]`, default = (1, 2)
348 Min and max value, to which the edge widths should be scaled.
349
350 Raises
351 ------
352 AttributeError
353 If no edge width attribute exists yet.
354 """
355 if not hasattr(self, 'width'):
356 raise AttributeError(
357 'Please call add_frame with edge_param=width before rescaling'
358 ' the widths.')
359
360 vmin = min(min(l) for l in self.width)
361 vmax = max(max(l) for l in self.width)
362
363 tmp = []
364 for il in self.width:
365 tmp.append(
366 self.__rescale(il, line_widths, values_min_max=(vmin, vmax)))
367 self.width = tmp
368
[docs]
369 def add_frame(
370 self, topology, edge_param: str,
371 scada_data: Optional[ScadaData],
372 parameter: str = 'flow_rate', statistic: str = 'mean',
373 pit: Optional[Union[int, Tuple[int]]] = None,
374 species: str = None,
375 intervals: Optional[Union[int, List[Union[int, float]]]] = None,
376 use_sensor_data: bool = None):
377 """
378 Adds a new frame of edge_color or edge width based on the given data
379 and statistic.
380
381 Parameters
382 ----------
383 topology : :class:`~epyt_flow.topology.NetworkTopology`
384 Topology object retrieved from the scenario, containing the
385 structure of the water distribution network.
386 edge_param : `str`
387 Method can be called with edge_width or edge_color to calculate
388 either the width or color for the next frame.
389 scada_data : :class:`~epyt_flow.simulation.scada.scada_data.ScadaData`
390 SCADA data created by the :class:`~epyt_flow.simulation.scenario_simulator.ScenarioSimulator`
391 instance, is used to retrieve data for the next frame.
392 parameter : `str`, default = 'flow_rate'
393 The link data to visualize. Options are 'flow_rate', 'link_quality',
394 'custom_data', 'bulk_species_concentration' or 'diameter'.
395 Default is 'flow_rate'.
396 statistic : `str`, default = 'mean'
397 The statistic to calculate for the data. Can be 'mean', 'min',
398 'max' or 'time_step'.
399 pit : `int`
400 The point in time for the 'time_step' statistic.
401 species: `str`, optional
402 Key of the species. Necessary only for parameter
403 'bulk_species_concentration'.
404 intervals : `int`, `list[int]` or `list[float]`
405 If provided, the data will be grouped into intervals. It can be an
406 integer specifying the number of groups or a list of boundary
407 points.
408 use_sensor_data : `bool`, optional
409 If `True`, instead of using raw simulation data, the data recorded
410 by the corresponding sensors in the system is used for the
411 visualization. Note: Not all components may have a sensor attached
412 and sensors may be subject to sensor faults or noise.
413
414 Raises
415 ------
416 ValueError
417 If parameter, interval, pit or statistic is not set correctly.
418 """
419 if edge_param == 'edge_width' and not hasattr(self, 'width'):
420 self.width = []
421 elif edge_param == 'edge_color':
422 if isinstance(self.edge_color, str):
423 self.edge_color = []
424 self.edge_vmin = float('inf')
425 self.edge_vmax = float('-inf')
426
427 if parameter == 'flow_rate':
428 if use_sensor_data:
429 values, self.mask = scada_data.get_data_flows_as_edge_features()
430 values = values[:, ::2]
431 self.mask = self.mask[::2]
432 else:
433 values = scada_data.flow_data_raw
434 elif parameter == 'link_quality':
435 if use_sensor_data:
436 values, self.mask = scada_data.get_data_links_quality_as_edge_features()
437 values = values[:, ::2]
438 self.mask = self.mask[::2]
439 else:
440 values = scada_data.link_quality_data_raw
441 elif parameter == 'custom_data':
442 values = scada_data
443 elif parameter == 'bulk_species_concentration':
444 if not species:
445 raise ValueError('Species must be given when using '
446 'bulk_species_concentration')
447 if use_sensor_data:
448 values, self.mask = scada_data.get_data_bulk_species_concentrations_as_edge_features()
449 self.mask = self.mask[::2,
450 scada_data.sensor_config.bulk_species.index(
451 species)]
452 values = values[:, ::2,
453 scada_data.sensor_config.bulk_species.index(species)]
454 else:
455 values = scada_data.bulk_species_link_concentration_raw[:,
456 scada_data.sensor_config.bulk_species.index(species),
457 :]
458 elif parameter == 'diameter':
459 value_dict = {
460 link[0]: topology.get_link_info(link[0])['diameter'] for
461 link in topology.get_all_links()}
462 sorted_values = [value_dict[x[0]] for x in
463 topology.get_all_links()]
464
465 if edge_param == 'edge_width':
466 self.width.append(sorted_values)
467 else:
468 self.edge_color.append(sorted_values)
469 self.edge_vmin = min(*sorted_values, self.edge_vmin)
470 self.edge_vmax = max(*sorted_values, self.edge_vmax)
471 return
472 else:
473 raise ValueError('Parameter must be flow_rate, link_quality, '
474 'diameter or custom_data.')
475
476 if statistic in stat_funcs:
477 stat_values = stat_funcs[statistic](values, axis=0)
478 elif statistic == 'time_step':
479 if not pit and pit != 0:
480 raise ValueError(
481 'Please input point in time (pit) parameter when selecting'
482 ' time_step statistic')
483 stat_values = np.take(values, pit, axis=0)
484 else:
485 raise ValueError(
486 'Statistic parameter must be mean, min, max or time_step')
487
488 if intervals is None:
489 pass
490 elif isinstance(intervals, (int, float)):
491 interv = np.linspace(stat_values.min(), stat_values.max(),
492 intervals + 1)
493 stat_values = np.digitize(stat_values, interv) - 1
494 elif isinstance(intervals, list):
495 stat_values = np.digitize(stat_values, intervals) - 1
496 else:
497 raise ValueError(
498 'Intervals must be either number of groups or list of interval'
499 ' boundary points')
500
501 sorted_values = list(stat_values)
502
503 if edge_param == 'edge_width':
504 self.width.append(sorted_values)
505 else:
506 self.edge_color.append(sorted_values)
507 self.edge_vmin = min(*sorted_values, self.edge_vmin)
508 self.edge_vmax = max(*sorted_values, self.edge_vmax)
509
[docs]
510 def get_frame(self, frame_number: int = 0):
511 """
512 Returns all attributes necessary for networkx to draw the specified
513 frame.
514
515 Parameters
516 ----------
517 frame_number : `int`, default = 0
518 The frame whose parameters should be returned. Default is 0, this
519 is also used if only 1 frame exists (e.g. for plots, not
520 animations).
521
522 Returns
523 -------
524 valid_params : `dict`
525 A dictionary containing all attributes that function as parameters
526 for `networkx.drawing.nx_pylab.draw_networkx_edges() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_edges.html#draw-networkx-edges>`_.
527 """
528 attributes = vars(self).copy()
529
530 if not isinstance(self.edge_color, str):
531 if 'edge_color' in self.interpolated.keys():
532 if frame_number > len(self.interpolated['edge_color']):
533 frame_number = -1
534 attributes['edge_color'] = self.interpolated['edge_color'][
535 frame_number]
536 else:
537 if frame_number > len(self.edge_color):
538 frame_number = -1
539 attributes['edge_color'] = self.edge_color[frame_number]
540
541 if hasattr(self, 'width'):
542 if 'width' in self.interpolated.keys():
543 if frame_number > len(self.interpolated['width']):
544 frame_number = -1
545 attributes['width'] = self.interpolated['width'][frame_number]
546 else:
547 if frame_number > len(self.width):
548 frame_number = -1
549 attributes['width'] = self.width[frame_number]
550
551 sig = inspect.signature(nxp.draw_networkx_edges)
552
553 valid_params = {
554 key: value for key, value in attributes.items()
555 if key in sig.parameters and value is not None
556 }
557
558 return valid_params
559
[docs]
560 def get_frame_mask(self, frame_number: int = 0, color='k'):
561 """
562 Returns all attributes necessary for networkx to draw the specified
563 frame mask.
564
565 Parameters
566 ----------
567 frame_number : `int`, default = 0
568 The frame whose parameters should be returned. Default is 0, this
569 is also used if only 1 frame exists (e.g. for plots, not
570 animations).
571 color:
572 The default color of masked nodes.
573
574 Returns
575 -------
576 valid_params : `dict`
577 A dictionary containing all attributes that function as parameters
578 for `networkx.drawing.nx_pylab.draw_networkx_edges() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_edges.html#draw-networkx-edges>`_.
579 """
580 attributes = vars(self).copy()
581
582 attributes['edgelist'] = [edge for edge, flag in
583 zip(self.edgelist, self.mask) if not flag]
584 attributes['edge_color'] = color
585
586 if hasattr(self, 'width'):
587 if 'width' in self.interpolated.keys():
588 if frame_number > len(self.interpolated['width']):
589 frame_number = -1
590 attributes['width'] = self.interpolated['width'][frame_number]
591 else:
592 if frame_number > len(self.width):
593 frame_number = -1
594 attributes['width'] = self.width[frame_number]
595 attributes['width'] = [edge for edge, flag in
596 zip(attributes['width'].copy(), self.mask)
597 if not flag]
598
599 sig = inspect.signature(nxp.draw_networkx_edges)
600
601 valid_params = {
602 key: value for key, value in attributes.items()
603 if key in sig.parameters and key not in ['vmin', 'vmax', 'cmap']
604 and value is not None
605 }
606
607 return valid_params
608
[docs]
609 def interpolate(self, num_inter_frames: int):
610 """
611 Interpolates edge_color and width values for smoother animations.
612
613 Parameters
614 ----------
615 num_inter_frames : `int`
616 The number of total frames after interpolation.
617 """
618 targets = {'edge_color': self.edge_color}
619 if hasattr(self, 'width'):
620 targets['width'] = self.width
621
622 for name, inter_target in targets.items():
623 if isinstance(inter_target, str) or len(inter_target) <= 1:
624 continue
625
626 tmp_target = np.array(inter_target)
627 steps, num_edges = tmp_target.shape
628
629 x_axis = np.linspace(0, steps - 1, steps)
630 new_x_axis = np.linspace(0, steps - 1, num_inter_frames)
631
632 vals_inter = np.zeros(((len(new_x_axis)), num_edges))
633
634 for edge in range(num_edges):
635 cs = CubicSpline(x_axis, tmp_target[:, edge])
636 vals_inter[:, edge] = cs(new_x_axis)
637
638 self.interpolated[name] = vals_inter
639
[docs]
640 def add_attributes(self, attributes: dict):
641 """
642 Adds the given attributes dict as class attributes.
643
644 Parameters
645 ----------
646 attributes : `dict`
647 Attributes dict, which is to be added as class attributes.
648 """
649 for key, value in attributes.items():
650 setattr(self, key, value)
651
652 def __rescale(self, values: Union[np.ndarray, list],
653 scale_min_max: Union[List, Tuple[int]],
654 values_min_max: Union[
655 List, Tuple[int, int]] = None) -> np.ndarray:
656 """
657 Rescales the given values to a new range.
658
659 This method rescales an array of values to fit within a specified
660 minimum and maximum scale range. Optionally, the minimum and maximum
661 of the input values can be manually provided; otherwise, they are
662 automatically determined from the data.
663
664 Parameters
665 ----------
666 values : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_ or `list`
667 The array of numerical values to be rescaled.
668 scale_min_max : `list` or `tuple`
669 A list or tuple containing two elements: the minimum and maximum
670 values of the desired output range.
671 values_min_max : `list` or `tuple`, optional
672 A list or tuple containing two elements: the minimum and maximum
673 values of the input data. If not provided, they are computed from
674 the input `values`. Default is `None`.
675
676 Returns
677 -------
678 rescaled_values : `list`
679 A list of values rescaled to the range specified by
680 `scale_min_max`.
681 """
682 if not values_min_max:
683 min_val, max_val = min(values), max(values)
684 else:
685 min_val, max_val = values_min_max
686 scale = scale_min_max[1] - scale_min_max[0]
687
688 def range_map(x):
689 return scale_min_max[0] + (x - min_val) / (
690 max_val - min_val) * scale
691
692 vectorized_range_map = np.vectorize(range_map)
693 rescaled_widths = vectorized_range_map(np.array(values))
694
695 if hasattr(self, 'mask'):
696 rescaled_widths = np.where(self.mask == 1, rescaled_widths, 1.0)
697
698 return rescaled_widths
699
700
[docs]
701@serializable(COLOR_SCHEMES_ID, ".epyt_flow_color_scheme")
702class ColorScheme(JsonSerializable):
703 """
704 A class containing the color scheme for the
705 :class:`~epyt_flow.visualization.ScenarioVisualizer`.
706 """
707
708 def __init__(self, pipe_color: str, node_color: str, pump_color: str,
709 tank_color: str, reservoir_color: str,
710 valve_color: str) -> None:
711 """Initializes the ColorScheme class with the given component colors.
712
713 Accepted formats are the string representations accepted by matplotlib:
714 https://matplotlib.org/stable/users/explain/colors/colors.html#color-formats
715
716 Parameters
717 ----------
718 pipe_color : str
719 String color format accepted by matplotlib.
720 node_color : str
721 String color format accepted by matplotlib.
722 pump_color : str
723 String color format accepted by matplotlib.
724 tank_color : str
725 String color format accepted by matplotlib.
726 reservoir_color : str
727 String color format accepted by matplotlib.
728 valve_color : str
729 String color format accepted by matplotlib.
730 """
731 self.pipe_color = pipe_color
732 self.node_color = node_color
733 self.pump_color = pump_color
734 self.tank_color = tank_color
735 self.reservoir_color = reservoir_color
736 self.valve_color = valve_color
737 super().__init__()
738
[docs]
739 def get_attributes(self) -> dict:
740 """
741 Gets all attributes needed for serialization.
742
743 Returns
744 -------
745 attr : A dictionary containing all attributes to be serialized.
746 """
747 attr = {
748 k: v for k, v in self.__dict__.items()
749 if
750 not (k.startswith("__") or k.startswith("_")) and not callable(v)
751 }
752 return super().get_attributes() | attr
753
754 def __eq__(self, other: any) -> bool:
755 """
756 Checks if two ColorScheme instances are equal.
757
758 Parameters
759 ----------
760 other : :class:`~epyt_flow.visualization_utils.ColorScheme`
761 The other ColorScheme instance to compare this one with.
762
763 Returns
764 -------
765 bool
766 True if all attributes are the same, False otherwise.
767 """
768 if not isinstance(other, ColorScheme):
769 return False
770 return (
771 self.pipe_color == other.pipe_color and
772 self.node_color == other.node_color and
773 self.pump_color == other.pump_color and
774 self.tank_color == other.tank_color and
775 self.reservoir_color == other.reservoir_color and
776 self.valve_color == other.valve_color
777 )
778
779 def __str__(self) -> str:
780 """
781 Returns a string representation of the ColorScheme instance.
782
783 Returns
784 -------
785 str
786 A string describing the ColorScheme instance.
787 """
788 return (f"ColorScheme(pipe_color={self.pipe_color}, "
789 f"node_color={self.node_color}, "
790 f"pump_color={self.pump_color}, "
791 f"tank_color={self.tank_color}, "
792 f"reservoir_color={self.reservoir_color}, "
793 f"valve_color={self.valve_color})")
794
795
796epanet_colors = ColorScheme(
797 pipe_color="#0403ee",
798 node_color="#0403ee",
799 pump_color="#fe00ff",
800 tank_color="#02fffd",
801 reservoir_color="#00ff00",
802 valve_color="#000000"
803)
804
805epyt_flow_colors = ColorScheme(
806 pipe_color="#29222f",
807 node_color="#29222f",
808 pump_color="#d79233",
809 tank_color="#607b80",
810 reservoir_color="#33483d",
811 valve_color="#a3320b"
812)
813
814black_colors = ColorScheme(
815 pipe_color="#000000",
816 node_color="#000000",
817 pump_color="#000000",
818 tank_color="#000000",
819 reservoir_color="#000000",
820 valve_color="#000000"
821)