Skip to content
Snippets Groups Projects
Verified Commit d37a67ed authored by Marco Aceti's avatar Marco Aceti
Browse files

Add detect_lines stat

parent 7181227f
Branches
No related tags found
No related merge requests found
Pipeline #2894 passed
......@@ -14,3 +14,4 @@ joblib==1.2.0
timple==0.1.5
colour==0.1.5
branca==0.6.0
itables==1.5.2
......@@ -74,6 +74,7 @@ def register_args(parser: argparse.ArgumentParser):
"delay_boxplot",
"day_train_count",
"trajectories_map",
"detect_lines",
),
default="describe",
)
......@@ -130,7 +131,7 @@ def main(args: argparse.Namespace):
logging.info(f"Loaded {len(df)} data points ({original_length} before filtering)")
# Tag lines
tag_lines(df, stations)
df = tag_lines(df, stations)
# Prepare graphics
stat.prepare_mpl(df, args)
......@@ -163,5 +164,8 @@ def main(args: argparse.Namespace):
elif args.stat == "trajectories_map":
if not isinstance(df, pd.DataFrame):
raise ValueError("can't use trajectories_map with unaggregated data")
trajectories_map.build_map(stations, df)
elif args.stat == "detect_lines":
if not isinstance(df, pd.DataFrame):
raise ValueError("can't use detect_lines with unaggregated data")
stat.detect_lines(df, stations)
......@@ -16,11 +16,14 @@
import argparse
import webbrowser
from tempfile import NamedTemporaryFile
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from itables import to_html_datatable
from pandas.core.groupby.generic import DataFrameGroupBy
from src.const import RAILWAY_COMPANIES_PALETTE, WEEKDAYS
......@@ -145,3 +148,43 @@ def day_train_count(df: pd.DataFrame | DataFrameGroupBy) -> None:
ax.set(xlabel="Day", ylabel="Train count")
plt.show()
def detect_lines(df: pd.DataFrame, st: pd.DataFrame) -> None:
    """Show an interactive table with the detected (by tag_lines) railway lines.

    The per-line summary is rendered to a temporary HTML file which is then
    opened in the default web browser. The file is created with
    ``delete=False`` so that it still exists when the browser reads it; it is
    not removed by this function.

    Args:
        df: Data frame providing the "line", "origin", "destination",
            "train_hash" and "stop_number" columns.
        st: Stations data frame whose index is matched against the "origin"
            and "destination" values of ``df``; it must provide the
            "long_name", "region", "latitude", "longitude" and "short_name"
            columns.
    """
    # Local import: only needed to build a portable file:// URI below.
    from pathlib import Path

    # Keep only the station name column(s) for the joins.
    # NOTE(review): presumably just "long_name" remains after the drop;
    # any other leftover column would collide on the second join -- confirm.
    st_names: pd.DataFrame = st.drop(
        ["region", "latitude", "longitude", "short_name"],
        axis=1,
    )

    # Resolve origin/destination to human-readable station names.
    named: pd.DataFrame = (
        df.join(st_names, on="origin")
        .rename({"long_name": "station_a"}, axis=1)
        .join(st_names, on="destination")
        .rename({"long_name": "station_b"}, axis=1)
    )

    lines: pd.DataFrame = (
        named[["line", "station_a", "station_b", "train_hash", "stop_number"]]
        .groupby("line")
        .agg(
            {
                "station_a": "first",
                "station_b": "first",
                "train_hash": "nunique",
                # NOTE(review): "+ 1" presumably converts a zero-based stop
                # index into a stop count -- confirm against tag_lines.
                "stop_number": lambda g: max(g) + 1,
            }
        )
        .rename({"train_hash": "train_count"}, axis=1)
        .sort_values(by="train_count", ascending=False)
        .reset_index()
    )

    html: str = to_html_datatable(
        lines,
        caption="Detected railway lines",
        lengthMenu=[20, 50, 100],
        # NOTE(review): column 3 is "train_count" provided the default index
        # is not rendered as a column -- confirm with the itables version.
        order=[3, "desc"],
        maxBytes=2**17,
    )

    # Write inside a context manager so the content is flushed and the handle
    # closed *before* the browser is asked to read the file; the ".html"
    # suffix lets the browser recognise the content type.
    with NamedTemporaryFile(
        mode="w",
        suffix=".html",
        encoding="utf-8",
        delete=False,
    ) as outfile:
        outfile.write(html)

    # webbrowser.open() expects a URL; passing a bare filename is not portable.
    webbrowser.open(Path(outfile.name).as_uri())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment