Skip to content
Snippets Groups Projects
Verified Commit 3c305f8a authored by Marco Aceti's avatar Marco Aceti
Browse files

Merge branch 'feature/timetables'

parents d09196ae 2f9b18c2
No related branches found
No related tags found
No related merge requests found
......@@ -82,6 +82,10 @@ $ python main.py ...
`$ python main.py analyze [..]/stations.csv [..]/trains.csv --stat trajectories_map`
- __Display a timetable graph__.
`$ python main.py analyze [..]/stations.csv [..]/trains.csv --stat timetable --timetable-collapse`
## Fields
### Stations CSV
......
......@@ -61,3 +61,22 @@ def railway_company_filter(
s.strip().lower() for s in railway_companies.strip().split(",") if len(s) > 0
]
return df.loc[df.client_code.str.lower().isin(code_list)]
def railway_lines_filter(df: pd.DataFrame, lines: str | None):
"""Filter dataframe by the railway line.
Args:
df (pd.DataFrame): the considered dataframe
line (str | None): a comma-separated list of railway lines
Returns:
pd.DataFrame: the filtered dataframe
"""
if not lines or len(lines) < 1:
return df
line_list: list[str] = [
l.strip().upper() for l in lines.strip().split(",") if len(l) > 0
]
return df.loc[df.line.isin(line_list)]
......@@ -26,8 +26,8 @@ from dateparser import parse
from joblib import Parallel, delayed
from pandas.core.groupby.generic import DataFrameGroupBy
from src.analysis import groupby, stat, trajectories_map
from src.analysis.filter import date_filter, railway_company_filter
from src.analysis import groupby, stat, timetable, trajectories_map
from src.analysis.filter import *
from src.analysis.load_data import read_station_csv, read_train_csv, tag_lines
......@@ -45,6 +45,15 @@ def register_args(parser: argparse.ArgumentParser):
help="comma-separated list of railway companies to include. If not set, all companies will be included.",
dest="client_codes",
)
parser.add_argument(
"--railway-lines",
help=(
"comma-separated list of railway lines to include. "
"If not set, all lines will be include. "
"Use --stat detect_lines to see available lines."
),
dest="railway_lines",
)
parser.add_argument(
"--group-by",
help="group by stops by a value",
......@@ -75,9 +84,16 @@ def register_args(parser: argparse.ArgumentParser):
"day_train_count",
"trajectories_map",
"detect_lines",
"timetable",
),
default="describe",
)
parser.add_argument(
"--timetable-collapse",
help="collapse the train stop times in the graph, relative to the first (only for 'timetable' stat). Defaults to False.",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"station_csv",
help="exported station CSV",
......@@ -110,6 +126,7 @@ def main(args: argparse.Namespace):
raise argparse.ArgumentTypeError("invalid end_date")
railway_companies: str | None = args.client_codes
railway_lines: str | None = args.railway_lines
# Load dataset
df: pd.DataFrame | DataFrameGroupBy = pd.DataFrame()
......@@ -125,14 +142,15 @@ def main(args: argparse.Namespace):
stations: pd.DataFrame = read_station_csv(args.station_csv)
original_length: int = len(df)
# Tag lines
df = tag_lines(df, stations)
# Apply filters
df = date_filter(df, start_date, end_date)
df = railway_company_filter(df, railway_companies)
df = railway_lines_filter(df, railway_lines)
logging.info(f"Loaded {len(df)} data points ({original_length} before filtering)")
# Tag lines
df = tag_lines(df, stations)
# Prepare graphics
stat.prepare_mpl(df, args)
......@@ -161,11 +179,23 @@ def main(args: argparse.Namespace):
stat.delay_boxplot(df)
elif args.stat == "day_train_count":
stat.day_train_count(df)
elif args.stat == "trajectories_map":
if not isinstance(df, pd.DataFrame):
raise ValueError("can't use trajectories_map with unaggregated data")
if args.stat in [
"trajectories_map",
"detect_lines",
"timetable",
] and not isinstance(df, pd.DataFrame):
raise ValueError(f"can't use {args.stat} with unaggregated data")
assert isinstance(df, pd.DataFrame)
if args.stat == "trajectories_map":
trajectories_map.build_map(stations, df)
elif args.stat == "detect_lines":
if not isinstance(df, pd.DataFrame):
raise ValueError("can't use detect_lines with unaggregated data")
stat.detect_lines(df, stations)
elif args.stat == "timetable":
if not timetable.same_line(df):
raise ValueError(
f"can't use timetable if --railway-lines filter is not used"
)
timetable.timetable_graph(df, stations, args.timetable_collapse)
# railway-opendata: scrape and analyze italian railway data
# Copyright (C) 2023 Marco Aceti
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import timple
from src.const import TIMEZONE, TIMEZONE_GMT
def same_line(df: pd.DataFrame) -> bool:
"""Check if the trains in the provided DataFrame are ALL on the same line
Args:
df (pd.DataFrame): the trains to check
Return:
bool: True if the trains are all on the same line, False otherwise
"""
return df.line.nunique() <= 1
def timetable_train(train: pd.DataFrame, expected: bool = False, collapse: bool = True):
"""Generate a timetable graph of a train
Args:
train (pd.DataFrame): the train stop data to consider
expected (bool, optional): determines whatever to consider the 'expected' or 'actual' arrival/departure times. Defaults to False.
collapse (bool, optional): determines whatever to _collapse_ the times in the graph, relative to the first. Defaults to True.
"""
if collapse:
train.value -= train.value.min()
train_f = train.loc[
train.variable.str.endswith("expected" if expected else "actual")
]
plt.plot(
train_f.value,
train_f.long_name,
"ko" if expected else "o",
linestyle="-" if expected else "--",
linewidth=3 if expected else 2,
label=f"{train.iloc[0].category} {train.iloc[0].number}"
if not expected
else "expected",
zorder=10 if expected else 5,
)
def timetable_graph(trains: pd.DataFrame, st: pd.DataFrame, collapse: bool = True):
"""Generate a timetable graph of trains in a line.
Args:
trains (pd.DataFrame): the train stop data to consider
st (pd.DataFrame): the station data
collapse (bool, optional): determines whatever to _collapse_ the times in the graph, relative to the first. Defaults to True.
"""
tmpl = timple.Timple()
tmpl.enable()
trains_j = (
trains.sort_values(by="stop_number")
.join(st, on="stop_station_code")
.reset_index(drop=True)
)
trains_m = (
pd.melt(
trains_j,
id_vars=[
"long_name",
"stop_number",
"train_hash",
"category",
"number",
"origin",
],
value_vars=[
"departure_expected",
"departure_actual",
"arrival_expected",
"arrival_actual",
],
)
.sort_values(["stop_number", "variable"])
.dropna()
)
# expected
if collapse:
for origin in trains_m.origin.unique():
train = list(trains_m.loc[trains_m.origin == origin].groupby("train_hash"))[0][1] # fmt: skip
timetable_train(train, True)
# actual
for _, train in trains_m.groupby("train_hash"):
timetable_train(train, False, collapse)
# get station names for proper title
st_names: pd.DataFrame = st.drop(
["region", "latitude", "longitude", "short_name"],
axis=1,
)
line: pd.DataFrame = (
trains.join(st_names, on="origin")
.rename({"long_name": "station_a"}, axis=1)
.join(st_names, on="destination")
.rename({"long_name": "station_b"}, axis=1)
)[["station_a", "station_b", "stop_number"]].agg(
{
"station_a": lambda s: s.iloc[0],
"station_b": lambda s: s.iloc[0],
"stop_number": lambda n: max(n) + 1,
}
)
plt.title(f"{line.station_a}{line.station_b} [{line.stop_number} stops]")
start_day, end_day = trains.day.min().date(), trains.day.max().date()
plt.title(f"{start_day} => {end_day}", loc="left")
plt.ylabel("Station")
plt.xlabel("Time")
ax = plt.gca()
ax.invert_yaxis()
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M", TIMEZONE if not collapse else TIMEZONE_GMT)) # type: ignore
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment