"""Load combined GTFS and OSM data into a graph."""
import multiprocessing as mp
import os
from functools import partial
import geopandas as gpd
import networkx as nx
import osmnx as ox
import pandas as pd
from shapely.geometry import LineString, Point
from .connectors import _fill_coordinates, connect_stops_to_streets
from .converters import parse_time_to_seconds
from .functions import validate_feed
from .other import logger
def _preprocess_schedules(graph: nx.DiGraph):
"""
Sorts the schedules on each edge for faster lookup.
"""
for _, _, data in graph.edges(data=True):
if "schedules" in data:
sorted_schedules = sorted(data["schedules"], key=lambda x: x[0])
data["sorted_schedules"] = sorted_schedules
data["departure_times"] = [schedule[0] for schedule in sorted_schedules]
def _add_edge_with_geometry(graph, start_stop, end_stop, schedule_info, geometry):
"""
Adds or updates an edge in the graph with schedule information and geometry.
"""
edge = (start_stop["stop_id"], end_stop["stop_id"])
if graph.has_edge(*edge):
graph[edge[0]][edge[1]]["schedules"].append(schedule_info)
if "geometry" not in graph[edge[0]][edge[1]]:
graph[edge[0]][edge[1]]["geometry"] = geometry
else:
graph.add_edge(
*edge, schedules=[schedule_info], type="transit", geometry=geometry
)
def _process_trip_group(
    group, graph, trips_df, shapes, trip_to_shape_map, stops_df, read_shapes
):
    """
    Processes a group of sorted stops for a single trip, adding edges between
    consecutive stops to the graph.

    Parameters
    ----------
    group : pd.DataFrame
        Stop times of a single trip, already sorted in travel order. Must
        provide the columns ``trip_id``, ``stop_id``, ``departure_time`` and
        ``arrival_time``.
    graph : networkx.DiGraph
        The graph to which the edges will be added (modified in place).
    trips_df : pd.DataFrame
        DataFrame containing trip information (``trip_id``, ``route_id`` and
        optionally ``wheelchair_accessible``).
    shapes : mapping or None
        Object with a ``.get(shape_id)`` method returning the shape geometry.
        Only used when ``read_shapes`` is True.
    trip_to_shape_map : dict or None
        Dictionary mapping trip IDs to shape IDs. Only used when
        ``read_shapes`` is True.
    stops_df : pd.DataFrame
        DataFrame containing stop information (``stop_id``, ``stop_lat``,
        ``stop_lon``). Only used when ``read_shapes`` is False.
    read_shapes : bool
        Flag indicating whether to read shape geometries from shapes.txt.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a departure time is later than the arrival time at the next stop,
        or if ``read_shapes`` is False and a stop of the trip has no
        coordinates in ``stops_df``.
    """
    trip_route_mapping = trips_df.set_index("trip_id")["route_id"].to_dict()

    # Some GTFS feeds do not have wheelchair_accessible information
    if "wheelchair_accessible" in trips_df.columns:
        trip_wheelchair_mapping = trips_df.set_index("trip_id")[
            "wheelchair_accessible"
        ].to_dict()
    else:
        trip_wheelchair_mapping = {}

    # Stop coordinates are only needed for the straight-line geometry, so the
    # lookup table is not built when geometries come from shapes.txt
    if read_shapes:
        stop_coords_mapping = None
    else:
        stop_coords_mapping = stops_df.set_index("stop_id")[
            ["stop_lat", "stop_lon"]
        ].to_dict("index")

    # Materialise each row once; every inner row is the end of one edge and
    # the start of the next
    rows = [group.iloc[i] for i in range(len(group))]

    # For each pair of consecutive stops in the group, add an edge to the graph
    for start_stop, end_stop in zip(rows, rows[1:]):
        departure = parse_time_to_seconds(start_stop["departure_time"])
        arrival = parse_time_to_seconds(end_stop["arrival_time"])

        if departure > arrival:
            raise ValueError(
                f"Departure time {departure} is greater than arrival time {arrival} for edge {start_stop['stop_id']} -> {end_stop['stop_id']}\n"
                "Negative travel time not allowed\n"
                "Check the GTFS feed for errors in stop_times.txt or calendar.txt, or adjust the departure time\n"
            )

        trip_id = start_stop["trip_id"]
        route_id = trip_route_mapping.get(trip_id)
        wheelchair_accessible = trip_wheelchair_mapping.get(trip_id, None)
        schedule_info = (departure, arrival, route_id, wheelchair_accessible)

        if read_shapes:
            # Use the shape geometry from shapes.txt
            shape_id = trip_to_shape_map.get(trip_id)
            geometry = shapes.get(shape_id)
        else:
            # Use the stop coordinates to create a simple LineString geometry
            try:
                start_coords = stop_coords_mapping[start_stop["stop_id"]]
                end_coords = stop_coords_mapping[end_stop["stop_id"]]
            except KeyError as err:
                raise ValueError(
                    f"Stop {err.args[0]!r} has no coordinates in the stops table"
                ) from err
            geometry = LineString(
                [
                    (start_coords["stop_lon"], start_coords["stop_lat"]),
                    (end_coords["stop_lon"], end_coords["stop_lat"]),
                ]
            )

        _add_edge_with_geometry(
            graph=graph,
            start_stop=start_stop,
            end_stop=end_stop,
            schedule_info=schedule_info,
            geometry=geometry,
        )
def _add_edges_parallel(
    trips_chunks, graph, trips_df, shapes, read_shapes, trip_to_shape_map, stops_df
):
    """
    Adds edges for one chunk of stop times to a copy of the graph.

    Intended to be used as the worker function of a process pool: the input
    graph is copied, the chunk is grouped by ``trip_id``, each trip is sorted
    by ``stop_sequence`` and handed to ``_process_trip_group``.

    Parameters
    ----------
    trips_chunks : pd.DataFrame
        Subset of stop times with at least ``trip_id`` and ``stop_sequence``
        columns.
    graph : networkx.DiGraph
        Base graph; it is copied and not modified.
    trips_df, shapes, read_shapes, trip_to_shape_map, stops_df
        Passed through unchanged to ``_process_trip_group``.

    Returns
    -------
    networkx.DiGraph
        Copy of ``graph`` with the edges of this chunk added.
    """
    local_graph = graph.copy()

    # Group by the scalar column name: a length-1 list key changes the type of
    # the group keys between pandas versions (and warns on some of them)
    for _, group in trips_chunks.groupby("trip_id"):
        sorted_group = group.sort_values("stop_sequence")
        _process_trip_group(
            group=sorted_group,
            graph=local_graph,
            trips_df=trips_df,
            shapes=shapes,
            trip_to_shape_map=trip_to_shape_map,
            stops_df=stops_df,
            read_shapes=read_shapes,
        )

    return local_graph
def _filter_stop_times_by_time(
    stop_times: pd.DataFrame, departure_time: int, duration_seconds: int
) -> pd.DataFrame:
    """
    Filters stop_times to the rows departing within a time window.

    The window is inclusive on both ends:
    ``departure_time <= t <= departure_time + duration_seconds``.
    The input DataFrame is not modified.

    Parameters
    ----------
    stop_times : pd.DataFrame
        Stop times with a ``departure_time`` column in a format accepted by
        ``parse_time_to_seconds``.
    departure_time : int
        Start of the window, in seconds.
    duration_seconds : int
        Length of the window, in seconds.

    Returns
    -------
    pd.DataFrame
        Matching rows, with an added ``departure_time_seconds`` column.
    """
    # assign() returns a new frame, so the caller's DataFrame does not
    # silently gain a column
    with_seconds = stop_times.assign(
        departure_time_seconds=stop_times["departure_time"].apply(
            parse_time_to_seconds
        )
    )
    in_window = with_seconds["departure_time_seconds"].between(
        departure_time, departure_time + duration_seconds
    )
    return with_seconds[in_window]
def _split_dataframe(df: pd.DataFrame, n_splits: int) -> list[pd.DataFrame]:
"""
Splits a DataFrame into n equal parts by rows.
This function replaces np.split_array which will be deprecated soon.
Parameters
----------
df : pandas DataFrame
The DataFrame to be split.
n_splits : int
The number of parts to split the DataFrame into.
Returns
-------
list of pandas DataFrames
A list of DataFrame parts.
"""
total_rows = len(df)
base_size = total_rows // n_splits
remainder = total_rows % n_splits
# Determine the number of rows each split will have
split_sizes = [
base_size + 1 if i < remainder else base_size for i in range(n_splits)
]
# Calculate the start indices for each split
start_indices = [sum(split_sizes[:i]) for i in range(n_splits)]
return [
df.iloc[start : start + size] for start, size in zip(start_indices, split_sizes)
]
def _load_GTFS(
    GTFSpath: str,
    departure_time_input: str,
    day_of_week: str,
    duration_seconds: int,
    read_shapes: bool = False,
    multiprocessing: bool = False,
) -> tuple[nx.DiGraph, pd.DataFrame]:
    """
    Loads GTFS data from the specified directory path and returns a graph and
    a dataframe of stops.

    Every stop in stops.txt becomes a node. For trips whose service runs on
    ``day_of_week``, each pair of consecutive stop times departing inside the
    requested time window becomes (or extends) an edge carrying schedule
    information. Edge schedules are sorted once the graph is complete.

    Parameters
    ----------
    GTFSpath : str
        Path to the directory containing GTFS data files.
    departure_time_input : str
        The departure time in 'HH:MM:SS' format.
    day_of_week : str
        Day of the week in lower case, e.g. "monday". Must be a column of
        calendar.txt.
    duration_seconds : int
        Duration of the time window to load in seconds.
    read_shapes : bool
        Geometry reading flag, passed from feed_to_graph.
    multiprocessing : bool
        If True, edges are built in a process pool and the partial graphs are
        merged afterwards. If False, edges are built sequentially.

    Returns
    -------
    tuple
        A tuple containing:
        - nx.DiGraph: Graph representing GTFS data.
        - pd.DataFrame: DataFrame containing stop information.

    Raises
    ------
    FileNotFoundError
        If ``read_shapes`` is True and shapes.txt is not present in GTFSpath.
    """
    # Initializing empty graph and read data files.
    G = nx.DiGraph()

    stops_df = pd.read_csv(
        os.path.join(GTFSpath, "stops.txt"), usecols=["stop_id", "stop_lat", "stop_lon"]
    )
    stop_times_df = pd.read_csv(
        os.path.join(GTFSpath, "stop_times.txt"),
        usecols=[
            "departure_time",
            "trip_id",
            "stop_id",
            "stop_sequence",
            "arrival_time",
        ],
    )
    routes = pd.read_csv(
        os.path.join(GTFSpath, "routes.txt"), usecols=["route_id", "route_short_name"]
    )
    trips_df = pd.read_csv(os.path.join(GTFSpath, "trips.txt"))
    calendar_df = pd.read_csv(os.path.join(GTFSpath, "calendar.txt"))

    # Load shapes.txt if read_shapes is True
    if read_shapes:
        logger.warning("Reading shapes is currently not working as intended")

        if "shapes.txt" not in os.listdir(GTFSpath):
            raise FileNotFoundError("shapes.txt not found")

        shapes_df = pd.read_csv(os.path.join(GTFSpath, "shapes.txt"))

        # The points of a shape only form a meaningful line in sequence
        # order, and the file is not guaranteed to be stored that way
        if "shape_pt_sequence" in shapes_df.columns:
            shapes_df = shapes_df.sort_values(["shape_id", "shape_pt_sequence"])

        # Group geometry by shape_id, resulting in a pandas Series with
        # shape_id as keys and LineString geometries as values.
        # NOTE: each LineString covers the whole shape, and that full
        # geometry is later attached to every stop-to-stop edge of a trip
        # rather than only the segment between the two stops.
        shapes = shapes_df.groupby("shape_id")[["shape_pt_lon", "shape_pt_lat"]].apply(
            lambda group: LineString(group.values)
        )

        # Mapping trip_id to shape_id for faster lookup
        trip_to_shape_map = trips_df.set_index("trip_id")["shape_id"].to_dict()
    else:
        shapes = None
        trip_to_shape_map = None

    # Join route information to trips
    trips_df = trips_df.merge(routes, on="route_id")

    # Filter trips by day of the week
    service_ids = calendar_df[calendar_df[day_of_week] == 1]["service_id"]
    trips_df = trips_df[trips_df["service_id"].isin(service_ids)]

    # Filter stop_times by valid trips
    valid_trips = stop_times_df["trip_id"].isin(trips_df["trip_id"])
    stop_times_df = stop_times_df[valid_trips].dropna()

    # Convert departure_time from HH:MM:SS to seconds
    departure_time_seconds = parse_time_to_seconds(departure_time_input)

    # Filtering stop_times by time window
    filtered_stops = _filter_stop_times_by_time(
        stop_times_df, departure_time_seconds, duration_seconds
    )

    print(f"{len(filtered_stops)} of {len(stop_times_df)} trips retained")

    # Adding stops as nodes to the graph
    for _, stop in stops_df.iterrows():
        G.add_node(
            stop["stop_id"],
            type="transit",
            pos=(stop["stop_lon"], stop["stop_lat"]),
            x=stop["stop_lon"],
            y=stop["stop_lat"],
        )

    if multiprocessing:
        print("Building graph in parallel")
        # Use half of the available CPU logical cores
        # (likely equal to the number of physical cores)
        num_cores = max(1, mp.cpu_count() // 2)

        # Divide the work by trip, not by row: edges connect consecutive stop
        # times of the same trip, so a trip cut in two by a chunk boundary
        # would lose the edge that straddles the cut
        trip_ids = pd.DataFrame({"trip_id": filtered_stops["trip_id"].unique()})
        chunks = [
            filtered_stops[filtered_stops["trip_id"].isin(id_chunk["trip_id"])]
            for id_chunk in _split_dataframe(trip_ids, num_cores)
        ]

        # Each worker returns a copy of the node-only graph with the edges
        # for its subset of trips
        add_edges_partial = partial(
            _add_edges_parallel,
            graph=G,
            trips_df=trips_df,
            shapes=shapes,
            read_shapes=read_shapes,
            trip_to_shape_map=trip_to_shape_map,
            stops_df=stops_df,
        )
        with mp.Pool(processes=num_cores) as pool:
            results = pool.map(add_edges_partial, chunks)

        # Merge the partial graphs into G, which already holds every node.
        # The same stop pair can be served by trips from different chunks,
        # so schedules of an existing edge are concatenated.
        for subgraph in results:
            for u, v, data in subgraph.edges(data=True):
                if G.has_edge(u, v):
                    G[u][v]["schedules"].extend(data["schedules"])
                else:
                    G.add_edge(u, v, **data)
    else:
        for _, group in filtered_stops.groupby("trip_id"):
            sorted_group = group.sort_values("stop_sequence")
            _process_trip_group(
                group=sorted_group,
                graph=G,
                trips_df=trips_df,
                shapes=shapes,
                trip_to_shape_map=trip_to_shape_map,
                stops_df=stops_df,
                read_shapes=read_shapes,
            )

    # Sorting schedules for faster lookup using binary search
    _preprocess_schedules(G)

    logger.info("Transit graph created")

    return G, stops_df
def _load_osm(stops: pd.DataFrame, save_graphml: bool, path) -> nx.DiGraph:
    """
    Builds a walking street graph covering the convex hull of the GTFS stops.

    The street network is downloaded through OSMnx for the convex hull of all
    stop points. Every edge is reduced to the attributes ``length``,
    ``highway`` and ``name`` and then given a ``weight``, ``type="street"``
    and a straight-line ``geometry`` between its end nodes. Every node is
    tagged with ``type="street"``.

    Parameters
    ----------
    stops : pandas.DataFrame
        Stops from the GTFS feed; must provide ``stop_lon`` and ``stop_lat``.
    save_graphml : bool
        Whether to save the resulting graph as a GraphML file.
    path : str
        File path for the GraphML file (used only if save_graphml is True).

    Returns
    -------
    G_city : networkx.DiGraph
        The street graph, converted with ``nx.DiGraph(...)``.
    """
    # Convex hull of all stop points delimits the area requested from OSM
    stop_points = gpd.points_from_xy(stops.stop_lon, stops.stop_lat)
    stops_gdf = gpd.GeoDataFrame(stops, geometry=stop_points)
    boundary = stops_gdf.unary_union.convex_hull

    logger.info("Loading OSM graph via OSMNX")
    G_city = ox.graph_from_polygon(boundary, network_type="walk", simplify=True)

    attributes_to_keep = {"length", "highway", "name"}
    node_data = G_city.nodes

    for u, v, _key, data in G_city.edges(keys=True, data=True):
        # Drop every attribute that is not explicitly kept
        for attribute in [name for name in data if name not in attributes_to_keep]:
            del data[attribute]

        # Walking time: length divided by 1.39
        # (presumably metres and metres per second -- TODO confirm units)
        data["weight"] = data["length"] / 1.39
        data["type"] = "street"

        # Straight segment between the two end nodes replaces any geometry
        # the edge carried before the cleanup above
        endpoints = [
            Point(node_data[node]["x"], node_data[node]["y"]) for node in (u, v)
        ]
        data["geometry"] = LineString(endpoints)

    nx.set_node_attributes(G_city, "street", "type")

    if save_graphml:
        ox.save_graphml(G_city, path)

    logger.info("Street network graph created")

    # NOTE(review): converting to DiGraph keeps a single edge per node pair;
    # which of several parallel edges survives is not controlled here
    return nx.DiGraph(G_city)
def feed_to_graph(
    GTFSpath: str,
    departure_time_input: str,
    day_of_week: str,
    duration_seconds: int,
    read_shapes: bool = False,
    multiprocessing: bool = True,
    input_graph_path: str = None,
    output_graph_path: str = None,
    save_graphml: bool = False,
    load_graphml: bool = False,
) -> nx.DiGraph:
    """
    Creates a directed graph (DiGraph) based on General Transit Feed
    Specification (GTFS) and OpenStreetMap (OSM) data.

    The transit graph built from the GTFS feed is composed with a street
    graph (downloaded via OSMnx or loaded from a GraphML file), after which
    the stops are connected to the street network.

    Parameters
    ----------
    GTFSpath : str
        Path to the GTFS files.
    departure_time_input : str
        Departure time in 'HH:MM:SS' format.
    day_of_week : str
        Day of the week in lowercase (e.g., 'monday').
    duration_seconds : int
        Time period from departure for which the graph will be loaded.
    read_shapes : bool, optional
        Flag for reading geometry from shapes.txt file. Default is False.
        This parameter is currently not working as intended.
    multiprocessing : bool, optional
        Flag for using multiprocessing. Default is True.
    input_graph_path : str, optional
        Path to the OSM graph file in GraphML format. Required when
        load_graphml is True. Default is None.
    output_graph_path : str, optional
        Path for saving the OSM graph in GraphML format. Default is None.
    save_graphml : bool, optional
        Flag for saving the OSM graph in GraphML format. Only takes effect
        when the street graph is downloaded (load_graphml is False).
        Default is False.
    load_graphml : bool, optional
        Flag for loading the OSM graph from a GraphML file. Default is False.

    Returns
    -------
    G_combined : nx.DiGraph
        Combined multimodal graph representing transit network.

    Raises
    ------
    ValueError
        If load_graphml is True but no input_graph_path is given, or if the
        GTFS feed does not pass validation.
    """
    # Fail before the expensive GTFS load when the request cannot succeed
    if load_graphml and input_graph_path is None:
        raise ValueError(
            "input_graph_path must be provided when load_graphml is True"
        )

    # Validate the GTFS feed
    if not validate_feed(GTFSpath):
        raise ValueError("The GTFS feed is not valid")

    G_transit, stops = _load_GTFS(
        GTFSpath,
        departure_time_input,
        day_of_week,
        duration_seconds,
        read_shapes=read_shapes,
        multiprocessing=multiprocessing,
    )

    if load_graphml:
        print("Loading OSM graph from GraphML file")
        # Dictionary with data types for edges
        edge_dtypes = {"weight": float, "length": float}
        G_city = ox.load_graphml(input_graph_path, edge_dtypes=edge_dtypes)
        G_city = nx.DiGraph(G_city)
    else:
        # Import OSM data
        G_city = _load_osm(stops, save_graphml, output_graph_path)

    # Combining OSM and GTFS data
    G_combined = nx.compose(G_transit, G_city)

    # Filling EPSG:4087 coordinates for graph nodes
    _fill_coordinates(G_combined)

    # Connecting stops to OSM streets
    connect_stops_to_streets(G_combined, stops)

    logger.info(
        f"Nodes: {G_combined.number_of_nodes()}, Edges: {G_combined.number_of_edges()}"
    )

    return G_combined
def load_stops_gdf(path) -> gpd.GeoDataFrame:
    """
    Read ``stops.txt`` from a GTFS directory into a GeoDataFrame.

    Parameters
    ----------
    path: str
        Directory that contains ``stops.txt``.

    Returns
    -------
    stops_gdf: gpd.GeoDataFrame
        All columns of ``stops.txt`` plus a point geometry built from
        ``stop_lon`` / ``stop_lat``, with CRS ``epsg:4326``.
    """
    stops_file = os.path.join(path, "stops.txt")
    stops_df = pd.read_csv(stops_file)

    # Longitude is the x coordinate, latitude the y coordinate
    stop_points = gpd.points_from_xy(stops_df.stop_lon, stops_df.stop_lat)

    return gpd.GeoDataFrame(stops_df, geometry=stop_points, crs="epsg:4326")