Source code for abacus.resplitter.resplit_builder

import pandas as pd
from abacus.resplitter.params import ResplitParams


[docs]class ResplitBuilder: """Builds stratification split for DataFrame.""" def __init__(self, df: pd.DataFrame, resplit_params: ResplitParams): """Makes restratification of dataframe Args: df (pandas.DataFrame): DataFrame for rebuild split. resplit_params (ResplitParams): Params for resplit. """ self.df = df self.params = resplit_params self._len_df = len(df) def collect(self) -> pd.DataFrame: """Method recalculate fractions of each strata in dataframe. Returns: pandas.DataFrame: DataFrame with recalculated strata fractions. """ df_restrata = pd.DataFrame() strata_all_count = self.df.value_counts(self.params.strata_col).reset_index( name="all_count" ) strata_test_counts = self.df.value_counts( [self.params.strata_col, self.params.group_col] ).reset_index(name="count") strata_test_counts = strata_test_counts.merge( strata_all_count, on=self.params.strata_col, how="inner" ) strata_test_counts["frequency"] = ( strata_test_counts["count"] / strata_test_counts["all_count"] ) min_group_freq = ( strata_test_counts.groupby(self.params.strata_col) .min()["frequency"] .reset_index(name="min_freq") ) strata_test_counts = strata_test_counts.merge( min_group_freq, on=self.params.strata_col ) strata_test_counts["disired_count"] = round( strata_test_counts["all_count"] * strata_test_counts["min_freq"], 0 ) strata_test_counts["disired_count"] = strata_test_counts[ "disired_count" ].astype(int) group_names = [ self.params.group_names.test_group_name, self.params.group_names.control_group_name, ] for group_name in group_names: for strata in self.df[self.params.strata_col].unique(): strata_group_count = ( strata_test_counts[ (strata_test_counts[self.params.group_col] == group_name) & (strata_test_counts[self.params.strata_col] == strata) ]["disired_count"] ).values[0] strata_group_values = self.df[ (self.df[self.params.group_col] == group_name) & (self.df[self.params.strata_col] == strata) ] df_strata = strata_group_values.sample(n=strata_group_count) df_restrata = pd.concat([df_restrata, df_strata]) return df_restrata