Skip to content
Snippets Groups Projects
Verified Commit c74145c0 authored by Marco Aceti's avatar Marco Aceti
Browse files

Add --railway-lines filter

parent d37a67ed
Branches develop
No related tags found
No related merge requests found
Pipeline #2896 passed
......@@ -61,3 +61,22 @@ def railway_company_filter(
s.strip().lower() for s in railway_companies.strip().split(",") if len(s) > 0
]
return df.loc[df.client_code.str.lower().isin(code_list)]
def railway_lines_filter(df: pd.DataFrame, lines: str | None):
"""Filter dataframe by the railway line.
Args:
df (pd.DataFrame): the considered dataframe
line (str | None): a comma-separated list of railway lines
Returns:
pd.DataFrame: the filtered dataframe
"""
if not lines or len(lines) < 1:
return df
line_list: list[str] = [
l.strip().upper() for l in lines.strip().split(",") if len(l) > 0
]
return df.loc[df.line.isin(line_list)]
......@@ -27,7 +27,7 @@ from joblib import Parallel, delayed
from pandas.core.groupby.generic import DataFrameGroupBy
from src.analysis import groupby, stat, trajectories_map
from src.analysis.filter import date_filter, railway_company_filter
from src.analysis.filter import *
from src.analysis.load_data import read_station_csv, read_train_csv, tag_lines
......@@ -45,6 +45,15 @@ def register_args(parser: argparse.ArgumentParser):
help="comma-separated list of railway companies to include. If not set, all companies will be included.",
dest="client_codes",
)
parser.add_argument(
"--railway-lines",
help=(
"comma-separated list of railway lines to include. "
"If not set, all lines will be include. "
"Use --stat detect_lines to see available lines."
),
dest="railway_lines",
)
parser.add_argument(
"--group-by",
help="group by stops by a value",
......@@ -110,6 +119,7 @@ def main(args: argparse.Namespace):
raise argparse.ArgumentTypeError("invalid end_date")
railway_companies: str | None = args.client_codes
railway_lines: str | None = args.railway_lines
# Load dataset
df: pd.DataFrame | DataFrameGroupBy = pd.DataFrame()
......@@ -125,14 +135,15 @@ def main(args: argparse.Namespace):
stations: pd.DataFrame = read_station_csv(args.station_csv)
original_length: int = len(df)
# Tag lines
df = tag_lines(df, stations)
# Apply filters
df = date_filter(df, start_date, end_date)
df = railway_company_filter(df, railway_companies)
df = railway_lines_filter(df, railway_lines)
logging.info(f"Loaded {len(df)} data points ({original_length} before filtering)")
# Tag lines
df = tag_lines(df, stations)
# Prepare graphics
stat.prepare_mpl(df, args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment