Source code for districtheatingsim.net_generation.minimal_spanning_tree

"""
Minimum Spanning Tree generation and road alignment optimization.

Implements MST algorithms for district heating network layout with iterative
road alignment adjustment while maintaining connectivity.

:author: Dipl.-Ing. (FH) Jonas Pfeiffer
"""

import geopandas as gpd
from shapely.geometry import LineString, Point
from shapely.ops import nearest_points
import networkx as nx
from collections import defaultdict
import numpy as np
from typing import Tuple, Optional, Set, Dict, Any

[docs] def generate_mst(points: gpd.GeoDataFrame) -> gpd.GeoDataFrame: """ Generate Minimum Spanning Tree from point locations. :param points: Point geometries for network terminals :type points: gpd.GeoDataFrame :return: MST network as LineString geometries :rtype: gpd.GeoDataFrame .. note:: Tree topology (n-1 edges for n points). Uses Kruskal's algorithm with Euclidean distances. """ # Build complete graph with distance weights g = nx.Graph() for i, point1 in points.iterrows(): for j, point2 in points.iterrows(): if i != j: distance = point1.geometry.distance(point2.geometry) g.add_edge(i, j, weight=distance) # Generate minimum spanning tree mst = nx.minimum_spanning_tree(g) # Convert MST edges to LineString geometries lines = [ LineString([points.geometry[edge[0]], points.geometry[edge[1]]]) for edge in mst.edges() ] return gpd.GeoDataFrame(geometry=lines)
[docs] def adjust_segments_to_roads(mst_gdf: gpd.GeoDataFrame, street_layer: gpd.GeoDataFrame, all_end_points_gdf: gpd.GeoDataFrame, threshold: float = 5.0, min_improvement: float = 0.5) -> gpd.GeoDataFrame: """ Iteratively adjust MST segments to follow street network. :param mst_gdf: MST segments to optimize :type mst_gdf: gpd.GeoDataFrame :param street_layer: Street network for alignment :type street_layer: gpd.GeoDataFrame :param all_end_points_gdf: Terminal points for MST reconstruction :type all_end_points_gdf: gpd.GeoDataFrame :param threshold: Distance threshold for adjustment trigger [m] (default 5.0) :type threshold: float :param min_improvement: Minimum improvement required [m] (default 0.5) :type min_improvement: float :return: Road-aligned network maintaining MST connectivity :rtype: gpd.GeoDataFrame :raises ValueError: If invalid geometries or empty network :raises RuntimeError: If fails to converge within iteration limit .. note:: Uses iterative improvement with blacklisting to prevent oscillation. Max 50 iterations. """ iteration = 0 changes_made = True max_iterations = 50 # Track segment adjustments and blacklist problematic segments segment_change_counter: Dict[int, int] = {} blacklist: Set[int] = set() def line_hash(line: LineString) -> int: """Create stable hash for line geometry based on rounded coordinates.""" coords = tuple((round(x, 3), round(y, 3)) for x, y in line.coords) return hash(coords) # Main optimization loop while changes_made and iteration < max_iterations: print(f"\n--- Road Alignment Iteration {iteration} ---") adjusted_lines = [] changes_made = False changed_this_iter: Set[int] = set() for idx, line in enumerate(mst_gdf.geometry): # Validate line geometry if not line.is_valid: print(f" [!] Invalid line geometry at index {idx}") continue seg_id = line_hash(line) if seg_id in blacklist: adjusted_lines.append(line) continue # Check if segment needs adjustment midpoint = line.interpolate(0.5, normalized=True) nearest_line_idx = street_layer.distance(midpoint).idxmin() nearest_street = street_layer.iloc[nearest_line_idx].geometry point_on_street = nearest_points(midpoint, nearest_street)[1] distance_to_street = midpoint.distance(point_on_street) if distance_to_street > threshold: # Avoid adjustments where projection point equals line endpoints if (point_on_street.equals(Point(line.coords[0])) or point_on_street.equals(Point(line.coords[1]))): print(f" Skipping adjustment: projected point is endpoint") adjusted_lines.append(line) continue # Calculate improvement for validation orig_distance = distance_to_street # Split line through projected street point new_line1 = LineString([line.coords[0], point_on_street.coords[0]]) new_line2 = LineString([point_on_street.coords[0], line.coords[1]]) # Validate and add new segments for new_line in [new_line1, new_line2]: if new_line.is_valid and not new_line.is_empty: # Check improvement quality new_midpoint = new_line.interpolate(0.5, normalized=True) nearest_line_new_idx = street_layer.distance(new_midpoint).idxmin() nearest_street_new = street_layer.iloc[nearest_line_new_idx].geometry point_on_street_new = nearest_points(new_midpoint, nearest_street_new)[1] new_distance = new_midpoint.distance(point_on_street_new) improvement = orig_distance - new_distance if improvement < min_improvement: print(f" Insufficient improvement ({improvement:.2f}m), blacklisting segment") blacklist.add(line_hash(new_line)) adjusted_lines.append(new_line) else: print(f" [!] Invalid new segment created") changes_made = True changed_this_iter.add(seg_id) segment_change_counter[seg_id] = segment_change_counter.get(seg_id, 0) + 1 print(f" Adjusted segment {idx}, total adjustments: {segment_change_counter[seg_id]}") else: adjusted_lines.append(line) # Progress reporting print(f" Adjusted {len(changed_this_iter)} segments this iteration") if changed_this_iter: most_changed = sorted(segment_change_counter.items(), key=lambda x: -x[1])[:3] print(f" Most frequently adjusted segments: {most_changed}") if not changes_made: print("No changes made, optimization converged") break mst_gdf = gpd.GeoDataFrame(geometry=adjusted_lines) iteration += 1 if iteration >= max_iterations: print(f"Warning: Reached maximum iterations ({max_iterations})") # Post-processing: simplify and rebuild MST print("\nPost-processing: simplifying network and rebuilding MST...") mst_gdf = simplify_network(mst_gdf) mst_gdf = extract_unique_points_and_create_mst(mst_gdf, all_end_points_gdf) print("Road alignment optimization completed") return mst_gdf
[docs] def simplify_network(gdf: gpd.GeoDataFrame, threshold: float = 10.0) -> gpd.GeoDataFrame: """ Simplify network by merging nearby points. :param gdf: Network segments to simplify :type gdf: gpd.GeoDataFrame :param threshold: Distance for merging points [m] (default 10.0) :type threshold: float :return: Simplified network with merged points :rtype: gpd.GeoDataFrame .. note:: Merges points within threshold to their centroid, maintaining connectivity. """ # Dictionary to store points and their associated line indices points = defaultdict(list) simplified_lines = [] # Extract endpoints from all line segments for idx, line in enumerate(gdf.geometry): start, end = line.boundary.geoms points[start].append(idx) points[end].append(idx) # Find and merge nearby points merged_points = {} for point in points: if point in merged_points: continue # Find all points within threshold distance nearby_points = [ p for p in points if p.distance(point) < threshold and p not in merged_points ] if not nearby_points: merged_points[point] = point else: # Calculate centroid of nearby points all_points = np.array([[p.x, p.y] for p in nearby_points]) centroid = Point(np.mean(all_points, axis=0)) # Map all nearby points to centroid for p in nearby_points: merged_points[p] = centroid # Reconstruct lines with merged endpoints for line in gdf.geometry: start, end = line.boundary.geoms new_start = merged_points.get(start, start) new_end = merged_points.get(end, end) # Create new line with updated endpoints if not new_start.equals(new_end): # Avoid zero-length lines simplified_lines.append(LineString([new_start, new_end])) return gpd.GeoDataFrame(geometry=simplified_lines)
[docs] def extract_unique_points_and_create_mst(gdf: gpd.GeoDataFrame, all_end_points_gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: """ Extract unique points and rebuild MST structure. :param gdf: Network segments for point extraction :type gdf: gpd.GeoDataFrame :param all_end_points_gdf: Terminal points to preserve :type all_end_points_gdf: gpd.GeoDataFrame :return: Reconstructed MST connecting all unique points :rtype: gpd.GeoDataFrame .. note:: Ensures tree topology (n-1 edges for n points) after network adjustments. """ # Extract coordinates from all line geometries all_points = [] for line in gdf.geometry: if isinstance(line, LineString): all_points.extend(line.coords) # Add terminal points from endpoint GeoDataFrame for point in all_end_points_gdf.geometry: if isinstance(point, Point): all_points.append((point.x, point.y)) # Remove duplicates and convert to Point objects unique_points = set(all_points) unique_points = [Point(pt) for pt in unique_points] # Create GeoDataFrame from unique points points_gdf = gpd.GeoDataFrame(geometry=unique_points) # Generate new MST from unique points mst_gdf = generate_mst(points_gdf) return mst_gdf