Skip to content
Snippets Groups Projects
Verified Commit 4ef0a88f authored by Marco Aceti's avatar Marco Aceti
Browse files

Add day_train_count stat

parent 2b5517cf
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,9 @@ def date_filter(
pd.DataFrame: the filtered dataframe
"""
if isinstance(start_date, datetime):
df = df.loc[df.day >= start_date.date()]
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
df = df.loc[df.day >= start_date]
if isinstance(end_date, datetime):
df = df.loc[df.day <= end_date.date()]
end_date = end_date.replace(hour=0, minute=0, second=0, microsecond=0)
df = df.loc[df.day <= end_date]
return df
......@@ -41,7 +41,7 @@ def register_args(parser: argparse.ArgumentParser):
"mean",
"last",
),
default="last",
default="none",
)
parser.add_argument(
"--stat",
......@@ -49,6 +49,7 @@ def register_args(parser: argparse.ArgumentParser):
choices=(
"describe",
"delay_boxplot",
"day_train_count",
),
default="describe",
)
......@@ -118,3 +119,5 @@ def main(args: argparse.Namespace):
stat.describe(df)
elif args.stat == "delay_boxplot":
stat.delay_boxplot(df)
elif args.stat == "day_train_count":
stat.day_train_count(df)
......@@ -17,6 +17,7 @@ def set_plot_title(df: pd.DataFrame, args: argparse.Namespace) -> None:
"""Set the plot title based on the cli arguments"""
if args.stat not in [
"delay_boxplot",
"day_train_count",
]:
return
......@@ -36,8 +37,6 @@ def set_plot_title(df: pd.DataFrame, args: argparse.Namespace) -> None:
def delay_boxplot(df: pd.DataFrame | DataFrameGroupBy) -> None:
"""Show a seaborn boxplot of departure and arrival delays"""
sns.set_theme(style="ticks", palette="pastel")
sns.set()
if isinstance(df, DataFrameGroupBy):
grouped_by: str = df.any().index.name
......@@ -82,3 +81,32 @@ def delay_boxplot(df: pd.DataFrame | DataFrameGroupBy) -> None:
plt.grid()
plt.show()
def day_train_count(df: pd.DataFrame | DataFrameGroupBy) -> None:
"""Show a seaborn barplot of unique train count, grouped by day"""
if isinstance(df, DataFrameGroupBy):
grouped_by: str = df.any().index.name
grouped = df.obj.groupby(["day", grouped_by]).nunique().reset_index()
grouped["day"] = grouped["day"].apply(lambda d: d.date().isoformat())
ax = sns.barplot(
data=grouped,
x="day",
y="train_hash",
hue=grouped_by,
)
elif isinstance(df, pd.DataFrame):
grouped = df.groupby("day").nunique().reset_index()
grouped["day"] = grouped["day"].apply(lambda d: d.date().isoformat())
ax = sns.barplot(
data=grouped,
x="day",
y="train_hash",
)
ax.set(xlabel="Day", ylabel="Train count")
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment