Skip to content
Snippets Groups Projects
Verified Commit bdbdc128 authored by Marco Aceti
Browse files

Add train loading and basic day filtering

parent ba05e019
Branches
No related tags found
No related merge requests found
from datetime import date, datetime

import pandas as pd
def date_filter(
df: pd.DataFrame, start_date: datetime | None, end_date: datetime | None
) -> pd.DataFrame:
"""Filter dataframe by date (day).
Args:
df (pd.DataFrame): the considered dataframe
start_date (datetime | None): the start date
end_date (datetime | None): the end date
Returns:
pd.DataFrame: the filtered dataframe
"""
if isinstance(start_date, datetime):
df = df.loc[df.day >= start_date.date()]
if isinstance(end_date, datetime):
df = df.loc[df.day <= end_date.date()]
return df
......@@ -27,6 +27,7 @@ def read_train_csv(file: Path) -> pd.DataFrame:
infer_datetime_format=True,
)
df.client_code = df.client_code.apply(RailwayCompany.from_code) # type: ignore
df.day = df.day.apply(lambda dt: dt.date())
return df.loc[(df.phantom == False) & (df.trenord_phantom == False)].drop(
["phantom", "trenord_phantom"], axis=1
)
......
import argparse
import logging
import pathlib
from datetime import datetime
import pandas as pd
from dateparser import parse
from tqdm import tqdm
from src.analysis.filter import date_filter
from src.analysis.load_data import read_station_csv, read_train_csv
def register_args(parser: argparse.ArgumentParser):
......@@ -37,4 +44,24 @@ def main(args: argparse.Namespace):
if args.end_date and not end_date:
raise argparse.ArgumentTypeError("invalid end_date")
raise NotImplementedError()
# Load dataset
df = pd.DataFrame()
logging.info("Loading datasets...")
for train_csv in (
tqdm(args.trains_csv)
if logging.root.getEffectiveLevel() > logging.DEBUG
else args.trains_csv
):
path = pathlib.Path(train_csv)
train_df: pd.DataFrame = read_train_csv(pathlib.Path(train_csv))
df = pd.concat([df, train_df], axis=0)
logging.debug(f"Loaded {len(train_df)} data points @ {path}")
stations: pd.DataFrame = read_station_csv(args.station_csv)
original_length: int = len(df)
# Apply filters
df = date_filter(df, start_date, end_date)
logging.info(f"Loaded {len(df)} data points ({original_length} before filtering)")
print(df)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment