Source code for pyxplor.plot_numeric

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import math

def plot_numeric(
    input_df: pd.DataFrame,
    list_of_variables: list,
    plot_kind: str,
    figsize: tuple = (8, 10),
    output: bool = False,
    super_title: str = "Distribution of Numeric Variables",
    super_title_font: int = 14,
    color: str = "lightblue",
    mean_color: str = "orange",
    median_color: str = "red",
) -> tuple:
    """Plot the distribution of numeric variables in a DataFrame.

    One subplot is drawn for each numeric column named in
    `list_of_variables`, arranged in a near-square grid. Columns that are
    listed but are not of a numeric dtype are silently skipped. The figure
    is written to ``numeric_variables.png`` in the current working
    directory only when `output` is True; the function never calls
    ``plt.show()`` itself.

    Parameters
    ----------
    input_df : pd.DataFrame
        The DataFrame that contains the numeric variables.
    list_of_variables : list
        List of numeric variables (column names) to be plotted.
    plot_kind : {'hist', 'kde', 'hist+kde'}
        Type of plot to be generated:
        'hist': Histogram only.
        'kde': Kernel density estimate plot only.
        'hist+kde': Histogram with a kernel density estimate plot showing
        median and mean.
    figsize : tuple[int, int], optional
        The width and height of the figure size in a tuple.
        Default is (8, 10).
    output : bool, optional
        Whether to output the figure to the current working directory.
        Default is False.
    super_title : str, optional
        Super title for the entire plot.
        Default is "Distribution of Numeric Variables".
    super_title_font : int, optional
        Font size for the super title. Default is 14.
    color : str, optional
        Color for the plot elements. Default is "lightblue".
    mean_color : str, optional
        Color for the mean line ('hist+kde' only). Default is "orange".
    median_color : str, optional
        Color for the median line ('hist+kde' only). Default is "red".

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib Figure object.
    ax : matplotlib.axes.Axes or array of Axes
        A single Axes when exactly one variable is plotted, otherwise a
        flat array of every Axes in the grid. Grid cells that hold no
        plot are included in the array but are hidden.

    Raises
    ------
    ValueError
        If `input_df` is not a DataFrame, if `list_of_variables` is empty
        or names columns missing from `input_df`, if `plot_kind`,
        `figsize`, `output` or `super_title_font` are invalid, or if none
        of the requested columns is numeric.

    Examples
    --------
    >>> numeric = ["numeric_var1", "numeric_var2"]
    >>> fig, ax = plot_numeric(df, numeric, "hist+kde")  # doctest: +SKIP
    """
    # Validate input_df
    if not isinstance(input_df, pd.DataFrame):
        raise ValueError("input_df must be a pandas DataFrame.")

    # Validate list_of_variables
    missing_vars = [var for var in list_of_variables if var not in input_df.columns]
    if missing_vars:
        # Column names are not necessarily strings; convert them so that
        # building the message cannot itself raise a TypeError.
        raise ValueError(
            "The following variables are not present in the DataFrame: "
            f"{', '.join(map(str, missing_vars))}."
        )
    if len(list_of_variables) == 0:
        raise ValueError("list_of_variables cannot be an empty list.")

    # Validate plot_kind
    valid_plot_kinds = {'hist', 'kde', 'hist+kde'}
    if plot_kind not in valid_plot_kinds:
        raise ValueError(
            "Invalid value for 'plot_kind'. It should be either 'hist', 'kde', or 'hist+kde'."
        )

    # Validate figsize
    if not (
        isinstance(figsize, tuple)
        and len(figsize) == 2
        and all(isinstance(val, (int, float)) for val in figsize)
    ):
        raise ValueError(
            "figsize must be a tuple of exactly two numbers (integers or floats)."
        )

    # Validate output
    if not isinstance(output, bool):
        raise ValueError("Output must be a boolean value.")

    # Validate super_title_font
    if not isinstance(super_title_font, (int, float)):
        raise ValueError("super_title_font must be a number (integer or float).")

    # Keep only the requested columns that have a numeric dtype, preserving
    # the order in which the caller listed them.
    numeric_dtype_columns = set(
        input_df.select_dtypes(include=['number']).columns
    )
    numeric_columns = [
        var for var in list_of_variables if var in numeric_dtype_columns
    ]
    if not numeric_columns:
        raise ValueError(
            "No valid numeric columns found in the provided list_of_variables."
        )

    # Lay the plots out on a near-square grid.
    total_plots = len(numeric_columns)
    rows = math.ceil(math.sqrt(total_plots))
    cols = math.ceil(total_plots / rows)

    # squeeze=False always yields a 2-D array of Axes, so flattening and
    # indexing work even for a single plot (a 1x1 grid would otherwise be
    # returned as a bare Axes object).
    fig, grid = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = grid.flatten()

    for i, variable in enumerate(numeric_columns):
        current_ax = axes[i]
        if plot_kind == 'hist':
            sns.histplot(x=variable, data=input_df, ax=current_ax, bins=20, color=color)
        elif plot_kind == 'kde':
            sns.kdeplot(x=variable, data=input_df, ax=current_ax, color=color)
        else:  # 'hist+kde' -- plot_kind was validated above
            sns.histplot(
                x=variable, data=input_df, ax=current_ax, bins=20, kde=True, color=color
            )

            # Mark the central tendency with labelled dashed lines.
            mean = input_df[variable].mean()
            median = input_df[variable].median()
            current_ax.axvline(
                mean, color=mean_color, linestyle='dashed', linewidth=2,
                label=f'Mean: {mean:.2f}',
            )
            current_ax.axvline(
                median, color=median_color, linestyle='dashed', linewidth=2,
                label=f'Median: {median:.2f}',
            )

            # Anchor the legend slightly outside the upper-right corner.
            current_ax.legend(loc='upper right', bbox_to_anchor=(1.1, 1.05))

        current_ax.set_title(variable)

    # The grid can have more cells than plots (e.g. 3 plots on 2x2);
    # hide the leftovers instead of showing empty frames.
    for unused_ax in axes[total_plots:]:
        unused_ax.set_visible(False)

    fig.suptitle(super_title, fontsize=super_title_font)

    # Reserve the top 4% of the figure for the super title.
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    if output:
        # Save through the Figure so the right figure is written even if
        # pyplot's "current figure" has changed.
        fig.savefig("numeric_variables.png")

    return fig, (axes[0] if total_plots == 1 else axes)