💻 System Design - Python Implementations

Working code examples for distributed systems, algorithms, and components

Purpose: This guide provides functional Python implementations of key system design components and algorithms. Each example includes complexity analysis, use cases, and production considerations.

🗺️ 1. Routing & Graph Algorithms

Essential for services like Uber, Google Maps, delivery routing, and network optimization.

Dijkstra's Algorithm - Shortest Path

Use Case: Finding shortest route from origin to destination (Uber, Google Maps, network routing)
import heapq
from typing import Dict, List, Tuple

class Graph:
    """Directed weighted graph stored as an adjacency list."""

    def __init__(self):
        # node -> [(neighbor, weight), ...]
        self.graph: Dict[str, List[Tuple[str, int]]] = {}

    def add_edge(self, from_node: str, to_node: str, weight: int):
        """Add a directed edge with weight"""
        self.graph.setdefault(from_node, []).append((to_node, weight))

    def dijkstra(self, start: str, end: str) -> Tuple[int, List[str]]:
        """
        Shortest path from start to end via Dijkstra with a lazy-deletion min-heap.

        Returns: (total_distance, path), or (inf, []) when end is unreachable.

        Time Complexity: O((V + E) log V) where V = vertices, E = edges
        Space Complexity: O(V)
        """
        # Heap entries: (distance so far, node, path taken to reach node)
        frontier = [(0, start, [start])]
        settled = set()
        best_dist = {start: 0}

        while frontier:
            dist, node, path = heapq.heappop(frontier)

            if node in settled:
                continue  # stale heap entry left over from an earlier push
            settled.add(node)

            # Destination reached: first settlement is optimal for Dijkstra.
            if node == end:
                return dist, path

            for nxt, weight in self.graph.get(node, []):
                if nxt in settled:
                    continue
                candidate = dist + weight
                # Push only if this is the best distance seen for `nxt` so far.
                if candidate < best_dist.get(nxt, float('inf')):
                    best_dist[nxt] = candidate
                    heapq.heappush(frontier, (candidate, nxt, path + [nxt]))

        return float('inf'), []  # No path found

# Example: City road network (directed, weighted edges)
g = Graph()
g.add_edge("A", "B", 4)
g.add_edge("A", "C", 2)
g.add_edge("B", "C", 1)
g.add_edge("B", "D", 5)
g.add_edge("C", "D", 8)
g.add_edge("C", "E", 10)
g.add_edge("D", "E", 2)

distance, path = g.dijkstra("A", "E")
print(f"Shortest distance from A to E: {distance}")  # Output: 11
print(f"Path: {' -> '.join(path)}")  # Output: A -> B -> D -> E (4 + 5 + 2 = 11)
Time Complexity: O((V + E) log V) with min-heap
Space Complexity: O(V) for distances and visited set
When to use: Non-negative edge weights, need actual shortest path

A* Algorithm - Optimized Pathfinding

Use Case: Google Maps navigation with heuristic (straight-line distance), game pathfinding
import heapq
import math
from typing import Dict, List, Tuple

class AStarGraph:
    """Weighted directed graph with per-node coordinates for A* search."""

    def __init__(self):
        # Adjacency list and the (x, y) position of each node.
        self.graph: Dict[str, List[Tuple[str, int]]] = {}
        self.coordinates: Dict[str, Tuple[float, float]] = {}

    def add_edge(self, from_node: str, to_node: str, weight: int):
        self.graph.setdefault(from_node, []).append((to_node, weight))

    def set_coordinates(self, node: str, x: float, y: float):
        """Set physical coordinates for heuristic calculation"""
        self.coordinates[node] = (x, y)

    def heuristic(self, node1: str, node2: str) -> float:
        """Straight-line (Euclidean) distance — admissible for geometric graphs."""
        (ax, ay) = self.coordinates[node1]
        (bx, by) = self.coordinates[node2]
        return math.sqrt((bx - ax)**2 + (by - ay)**2)

    def a_star(self, start: str, goal: str) -> Tuple[int, List[str]]:
        """
        A* search: orders the frontier by f(n) = g(n) + h(n), where
        g(n) is the actual cost from start and h(n) the Euclidean
        estimate from n to goal.

        Returns (cost, path), or (inf, []) when goal is unreachable.

        Time Complexity: O((V + E) log V) - often faster than Dijkstra in practice
        Space Complexity: O(V)
        """
        # Heap entries: (f_score, g_score, node, path)
        frontier = [(0, 0, start, [start])]
        done = set()
        best_g = {start: 0}

        while frontier:
            _, g_cost, node, path = heapq.heappop(frontier)

            if node in done:
                continue  # stale entry
            done.add(node)

            if node == goal:
                return g_cost, path

            for nxt, weight in self.graph.get(node, []):
                if nxt in done:
                    continue
                tentative = g_cost + weight
                if tentative < best_g.get(nxt, float('inf')):
                    best_g[nxt] = tentative
                    priority = tentative + self.heuristic(nxt, goal)
                    heapq.heappush(frontier, (priority, tentative, nxt, path + [nxt]))

        return float('inf'), []

# Example: City grid with coordinates
g = AStarGraph()

# Set coordinates (x, y) - used only by the Euclidean heuristic
g.set_coordinates("A", 0, 0)
g.set_coordinates("B", 4, 1)
g.set_coordinates("C", 2, 2)
g.set_coordinates("D", 6, 4)
g.set_coordinates("E", 8, 5)

# Add edges (same graph as the Dijkstra example above)
g.add_edge("A", "B", 4)
g.add_edge("A", "C", 2)
g.add_edge("B", "C", 1)
g.add_edge("B", "D", 5)
g.add_edge("C", "D", 8)
g.add_edge("C", "E", 10)
g.add_edge("D", "E", 2)

distance, path = g.a_star("A", "E")
print(f"A* distance: {distance}")  # 11 - same optimal cost Dijkstra finds
print(f"A* path: {' -> '.join(path)}")  # A -> B -> D -> E
A* vs Dijkstra: A* is faster when you have a good heuristic (like straight-line distance in maps). It explores fewer nodes by prioritizing paths that move toward the goal. Dijkstra explores uniformly in all directions.

Traveling Salesman Problem (TSP) - Approximation

Use Case: Delivery route optimization (UPS, DoorDash), tour planning, circuit board drilling
import itertools
from typing import List, Tuple

class TSP:
    """Traveling Salesman Problem - find shortest route visiting all cities.

    `distances` maps ordered city pairs to cost: {(city1, city2): distance}.
    Provide both directions for symmetric problems. A route is a closed
    tour: it starts and ends at the same city.
    """

    def __init__(self, cities: List[str], distances: dict[tuple[str, str], int]):
        """
        cities: List of city names
        distances: {(city1, city2): distance} dictionary

        Note: annotated with the builtin `dict` generic because this module
        imports only List/Tuple from typing; the original `Dict` annotation
        raised NameError at class-definition time.
        """
        self.cities = cities
        self.distances = distances

    def brute_force(self) -> Tuple[int, List[str]]:
        """
        Brute force: try all permutations.

        Time Complexity: O(n!) - only feasible for n < 12
        Space Complexity: O(n)
        """
        min_distance = float('inf')
        best_route = []

        # Fix the first city and permute the rest, so pure rotations of the
        # same tour are not enumerated repeatedly.
        for perm in itertools.permutations(self.cities[1:]):
            route = [self.cities[0]] + list(perm) + [self.cities[0]]
            distance = self._calculate_route_distance(route)

            if distance < min_distance:
                min_distance = distance
                best_route = route

        return min_distance, best_route

    def nearest_neighbor(self, start: str = None) -> Tuple[int, List[str]]:
        """
        Greedy approximation: always go to nearest unvisited city.

        Time Complexity: O(n²)
        Space Complexity: O(n)

        No constant-factor guarantee in general (adversarial inputs can make
        it arbitrarily bad), but it typically lands within ~25% of optimal.
        """
        if start is None:
            start = self.cities[0]

        unvisited = set(self.cities)
        current = start
        route = [current]
        unvisited.remove(current)
        total_distance = 0

        while unvisited:
            # Find nearest unvisited city; missing pairs rank as unreachable.
            nearest = min(unvisited,
                         key=lambda city: self.distances.get((current, city), float('inf')))

            distance = self.distances.get((current, nearest), 0)
            total_distance += distance
            route.append(nearest)
            unvisited.remove(nearest)
            current = nearest

        # Return to start to close the tour
        total_distance += self.distances.get((current, start), 0)
        route.append(start)

        return total_distance, route

    def two_opt(self, route: List[str], max_iterations: int = 1000) -> Tuple[int, List[str]]:
        """
        2-opt improvement: repeatedly reverse a segment of the tour whenever
        doing so shortens the total distance.

        Time Complexity: O(n² × iterations)
        Space Complexity: O(n)

        Much better than nearest neighbor, often within 5% of optimal
        """
        best_route = route[:]
        best_distance = self._calculate_route_distance(best_route)
        improved = True
        iteration = 0

        while improved and iteration < max_iterations:
            improved = False
            iteration += 1

            # Endpoints (the fixed start/end city) are never moved.
            for i in range(1, len(best_route) - 2):
                for j in range(i + 1, len(best_route) - 1):
                    # Try reversing segment between i and j
                    new_route = best_route[:i] + best_route[i:j+1][::-1] + best_route[j+1:]
                    new_distance = self._calculate_route_distance(new_route)

                    if new_distance < best_distance:
                        best_route = new_route
                        best_distance = new_distance
                        improved = True
                        break  # restart the scan from the improved route

                if improved:
                    break

        return best_distance, best_route

    def solve(self) -> Tuple[int, List[str]]:
        """
        Practical TSP solver: nearest-neighbor seed + 2-opt refinement.

        Time Complexity: O(n³) in practice
        Works well for n < 1000
        """
        # Start with nearest neighbor
        distance, route = self.nearest_neighbor()

        # Improve with 2-opt
        distance, route = self.two_opt(route)

        return distance, route

    def _calculate_route_distance(self, route: List[str]) -> int:
        """Sum edge distances along `route`; unknown pairs contribute 0."""
        total = 0
        for i in range(len(route) - 1):
            total += self.distances.get((route[i], route[i+1]), 0)
        return total

# Example: Delivery route optimization
cities = ["Warehouse", "Customer1", "Customer2", "Customer3", "Customer4"]

# Distance matrix (symmetric: both (a, b) and (b, a) are listed explicitly)
distances = {
    ("Warehouse", "Customer1"): 10, ("Customer1", "Warehouse"): 10,
    ("Warehouse", "Customer2"): 15, ("Customer2", "Warehouse"): 15,
    ("Warehouse", "Customer3"): 20, ("Customer3", "Warehouse"): 20,
    ("Warehouse", "Customer4"): 25, ("Customer4", "Warehouse"): 25,
    ("Customer1", "Customer2"): 35, ("Customer2", "Customer1"): 35,
    ("Customer1", "Customer3"): 20, ("Customer3", "Customer1"): 20,
    ("Customer1", "Customer4"): 30, ("Customer4", "Customer1"): 30,
    ("Customer2", "Customer3"): 30, ("Customer3", "Customer2"): 30,
    ("Customer2", "Customer4"): 12, ("Customer4", "Customer2"): 12,
    ("Customer3", "Customer4"): 15, ("Customer4", "Customer3"): 15,
}

tsp = TSP(cities, distances)

# Solve with hybrid approach (nearest-neighbor seed + 2-opt refinement)
distance, route = tsp.solve()
print(f"Optimized delivery route distance: {distance}")
print(f"Route: {' -> '.join(route)}")

# For small datasets, compare with brute force (O(n!) is tolerable for n <= 8)
if len(cities) <= 8:
    optimal_distance, optimal_route = tsp.brute_force()
    print(f"\nOptimal (brute force) distance: {optimal_distance}")
    print(f"Optimal route: {' -> '.join(optimal_route)}")
TSP is NP-Hard! No polynomial-time exact algorithm is known (and none exists unless P = NP). In production:
  • Small (n < 12): Use brute force or dynamic programming
  • Medium (n < 1000): Use nearest neighbor + 2-opt
  • Large (n > 1000): Use advanced heuristics (Christofides, genetic algorithms, simulated annealing)

📍 2. Geospatial Data Structures

Critical for location-based services like Uber, Yelp, real estate apps.

Geohash Implementation

Use Case: Uber driver search, Yelp "restaurants near me", indexing geographic data
class Geohash:
    """
    Geohash encodes latitude/longitude into a short string.
    Nearby locations share common prefixes.

    Example:
    - San Francisco: 9q8yy
    - Nearby location: 9q8yz (shares prefix 9q8y)
    - Far location: dr5ru (different prefix)
    """

    # Geohash base-32 alphabet (digits + lowercase letters minus a, i, l, o)
    BASE32 = "0123456789bcdefghjkmnpqrstuvwxyz"

    @staticmethod
    def encode(latitude: float, longitude: float, precision: int = 6) -> str:
        """
        Encode lat/lon to geohash string by binary-subdividing the
        longitude and latitude ranges, interleaving one bit at a time
        (longitude first) and packing each 5 bits into a base-32 char.

        precision levels:
        1: ±2500 km
        2: ±630 km
        3: ±78 km
        4: ±20 km
        5: ±2.4 km
        6: ±610 m (good for city-level queries)
        7: ±76 m
        8: ±19 m

        Time Complexity: O(precision)
        """
        lat_lo, lat_hi = -90.0, 90.0
        lon_lo, lon_hi = -180.0, 180.0
        chars = []
        value = 0       # 5-bit accumulator for the next output character
        filled = 0      # how many bits of `value` are set so far
        take_lon = True # bits alternate: lon, lat, lon, lat, ...

        while len(chars) < precision:
            if take_lon:
                mid = (lon_lo + lon_hi) / 2
                if longitude > mid:
                    value |= (1 << (4 - filled))
                    lon_lo = mid
                else:
                    lon_hi = mid
            else:
                mid = (lat_lo + lat_hi) / 2
                if latitude > mid:
                    value |= (1 << (4 - filled))
                    lat_lo = mid
                else:
                    lat_hi = mid

            take_lon = not take_lon

            filled += 1
            if filled == 5:
                chars.append(Geohash.BASE32[value])
                value = 0
                filled = 0

        return ''.join(chars)

    @staticmethod
    def decode(geohash: str) -> tuple[float, float]:
        """
        Decode geohash to (latitude, longitude) - the center of its cell.

        Time Complexity: O(len(geohash))
        """
        lat_lo, lat_hi = -90.0, 90.0
        lon_lo, lon_hi = -180.0, 180.0
        take_lon = True

        for char in geohash:
            value = Geohash.BASE32.index(char)

            # Consume the character's 5 bits, most significant first.
            for shift in (4, 3, 2, 1, 0):
                bit = (value >> shift) & 1

                if take_lon:
                    mid = (lon_lo + lon_hi) / 2
                    if bit:
                        lon_lo = mid
                    else:
                        lon_hi = mid
                else:
                    mid = (lat_lo + lat_hi) / 2
                    if bit:
                        lat_lo = mid
                    else:
                        lat_hi = mid

                take_lon = not take_lon

        return (lat_lo + lat_hi) / 2, (lon_lo + lon_hi) / 2

    @staticmethod
    def neighbors(geohash: str) -> List[str]:
        """
        Get 8 neighboring geohash cells (N, NE, E, SE, S, SW, W, NW).

        Use for "find nearby" queries:
        1. Get geohash of query point
        2. Get all 9 cells (self + 8 neighbors)
        3. Query database WHERE geohash IN (cells)

        Simplified approach: offset the decoded center by one cell size in
        each direction and re-encode.
        """
        lat, lon = Geohash.decode(geohash)
        precision = len(geohash)

        # Approximate cell size at this precision (exact for even precision,
        # where latitude and longitude each receive precision*5/2 bits).
        step_lat = 180.0 / (2 ** (precision * 2.5))
        step_lon = 360.0 / (2 ** (precision * 2.5))

        cells = []
        for off_lat in (-step_lat, 0, step_lat):
            for off_lon in (-step_lon, 0, step_lon):
                if off_lat == 0 and off_lon == 0:
                    continue  # skip the center cell itself
                cells.append(Geohash.encode(lat + off_lat, lon + off_lon, precision))

        return cells

# Example: Uber driver location indexing
print("=== Geohash Examples ===")

# San Francisco coordinates
sf_lat, sf_lon = 37.7749, -122.4194
sf_geohash = Geohash.encode(sf_lat, sf_lon, precision=6)
print(f"San Francisco: {sf_geohash}")

# Nearby location (Oakland)
oakland_lat, oakland_lon = 37.8044, -122.2712
oakland_geohash = Geohash.encode(oakland_lat, oakland_lon, precision=6)
print(f"Oakland: {oakland_geohash}")

# Compute the actual shared prefix rather than assuming 4 characters:
# SF ("9q8...") and Oakland ("9q9...") diverge at the third character, so
# their true common prefix is "9q" (a coarse ~630 km cell).
common_len = 0
for sf_char, oak_char in zip(sf_geohash, oakland_geohash):
    if sf_char != oak_char:
        break
    common_len += 1
print(f"Common prefix: {sf_geohash[:common_len]}")

# Far location (New York)
ny_lat, ny_lon = 40.7128, -74.0060
ny_geohash = Geohash.encode(ny_lat, ny_lon, precision=6)
print(f"New York: {ny_geohash}")
print(f"Different prefix: no common prefix with SF")

# Find nearby cells for query
print(f"\nNeighbors of {sf_geohash}:")
for neighbor in Geohash.neighbors(sf_geohash):
    print(f"  {neighbor}")
Uber's Geohash Strategy:
  • Driver locations indexed by geohash (precision 6-7 for city-level)
  • When rider requests: Calculate rider's geohash + 8 neighbors
  • Query: SELECT * FROM drivers WHERE geohash IN (cells) AND available=true
  • Returns ~100-200 drivers instantly (indexed query)
  • Then calculate exact distances for top candidates

QuadTree - Spatial Indexing

Use Case: Game engines, map rendering, collision detection, spatial databases
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class Point:
    """A 2-D point with an optional payload (e.g. the record stored at this location)."""
    x: float
    y: float
    # Annotated with typing.Any - the original annotation used the builtin
    # function `any`, which is not a type.
    data: Any = None

@dataclass
class Rectangle:
    """Axis-aligned rectangle described by its center point plus width/height."""
    x: float  # center x
    y: float  # center y
    width: float
    height: float

    def contains(self, point: Point) -> bool:
        """True when `point` lies inside this rectangle (edges inclusive)."""
        half_w = self.width / 2
        half_h = self.height / 2
        within_x = self.x - half_w <= point.x <= self.x + half_w
        within_y = self.y - half_h <= point.y <= self.y + half_h
        return within_x and within_y

    def intersects(self, range: 'Rectangle') -> bool:
        """True when this rectangle and `range` overlap (touching edges count)."""
        # Overlap on both axes: neither rectangle lies entirely to one side
        # of the other (De Morgan form of the "not disjoint" test).
        return (range.x - range.width/2 <= self.x + self.width/2 and
                range.x + range.width/2 >= self.x - self.width/2 and
                range.y - range.height/2 <= self.y + self.height/2 and
                range.y + range.height/2 >= self.y - self.height/2)

class QuadTree:
    """
    Region quadtree over 2-D points for fast spatial range queries.

    A node stores up to `capacity` points directly; once full it splits
    into four child quadrants (NW, NE, SW, SE) and later insertions spill
    into them. Points already held by a node stay where they are.

    Time Complexity:
    - Insert: O(log n) average, O(n) worst case
    - Query range: O(log n + k) where k = results found

    Space Complexity: O(n)
    """

    def __init__(self, boundary: Rectangle, capacity: int = 4):
        self.boundary = boundary
        self.capacity = capacity
        self.points: List[Point] = []
        self.divided = False

        # Child quadrants - populated by subdivide()
        self.northwest: Optional[QuadTree] = None
        self.northeast: Optional[QuadTree] = None
        self.southwest: Optional[QuadTree] = None
        self.southeast: Optional[QuadTree] = None

    def insert(self, point: Point) -> bool:
        """Insert a point; returns False if it falls outside this node's boundary."""
        if not self.boundary.contains(point):
            return False

        # Room left in this node: store the point here.
        if len(self.points) < self.capacity:
            self.points.append(point)
            return True

        # Full - lazily create the four children on first overflow.
        if not self.divided:
            self.subdivide()

        # Hand the point to the first child whose boundary contains it
        # (edge-inclusive boundaries mean the first match wins on borders).
        for child in (self.northwest, self.northeast, self.southwest, self.southeast):
            if child.insert(point):
                return True
        return False

    def subdivide(self):
        """Create the four equal child quadrants covering this boundary."""
        cx = self.boundary.x
        cy = self.boundary.y
        half_w = self.boundary.width / 2
        half_h = self.boundary.height / 2

        self.northwest = QuadTree(Rectangle(cx - half_w/2, cy + half_h/2, half_w, half_h), self.capacity)
        self.northeast = QuadTree(Rectangle(cx + half_w/2, cy + half_h/2, half_w, half_h), self.capacity)
        self.southwest = QuadTree(Rectangle(cx - half_w/2, cy - half_h/2, half_w, half_h), self.capacity)
        self.southeast = QuadTree(Rectangle(cx + half_w/2, cy - half_h/2, half_w, half_h), self.capacity)

        self.divided = True

    def query_range(self, range: Rectangle) -> List[Point]:
        """
        Find all points within a given rectangle.

        Used for: "Find all restaurants within visible map bounds"
        """
        # Prune whole subtrees whose boundary misses the query window.
        if not self.boundary.intersects(range):
            return []

        matches = [p for p in self.points if range.contains(p)]

        if self.divided:
            for child in (self.northwest, self.northeast, self.southwest, self.southeast):
                matches.extend(child.query_range(range))

        return matches

# Example: Restaurant locations in a city
print("\n=== QuadTree Example ===")

# City bounds: 100km x 100km square centered at the origin
boundary = Rectangle(0, 0, 100, 100)
qt = QuadTree(boundary, capacity=4)

# Insert restaurants (x, y, payload)
restaurants = [
    Point(10, 10, "Pizza Palace"),
    Point(15, 12, "Burger Joint"),
    Point(11, 11, "Sushi Bar"),
    Point(-20, 30, "Taco Stand"),
    Point(-25, 35, "Noodle House"),
    Point(40, -40, "Steakhouse"),
    Point(42, -38, "Cafe"),
]

for r in restaurants:
    qt.insert(r)

# Query: Find restaurants in viewport - a 20x20 area centered at the origin,
# i.e. x and y both in [-10, 10] (edges inclusive).
viewport = Rectangle(0, 0, 20, 20)
results = qt.query_range(viewport)

print(f"Restaurants in viewport: {len(results)}")
for r in results:
    print(f"  {r.data} at ({r.x}, {r.y})")
QuadTree vs Geohash:
QuadTree: Dynamic, adapts to data distribution, better for uneven distributions
Geohash: Static grid, simpler to implement, works well with databases (just a string column + index)
In Production: Most use Geohash for simplicity. QuadTree used in games/graphics where you control the data structure.

🔄 3. Caching & Data Structures

LRU Cache Implementation

Use Case: Redis, browser cache, database query cache, CDN
from collections import OrderedDict
from typing import Any, Optional

class LRUCache:
    """
    Least Recently Used (LRU) Cache with O(1) get and put.

    Built on OrderedDict: the dict's iteration order doubles as the recency
    order - the front holds the least recently used entry, the back the most
    recently used. move_to_end() refreshes recency; popitem(last=False) evicts.

    Time Complexity: O(1) for get, put, delete
    Space Complexity: O(capacity)
    """

    def __init__(self, capacity: int):
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value (refreshing its recency), or None if absent."""
        try:
            value = self.cache[key]
        except KeyError:
            return None

        self.cache.move_to_end(key)  # now most recently used
        return value

    def put(self, key: str, value: Any) -> None:
        """Store key -> value, evicting the least recently used entry when full."""
        if key in self.cache:
            # Existing key: refresh recency before overwriting.
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            # New key at capacity: drop the LRU entry (front of the dict).
            evicted_key, evicted_value = self.cache.popitem(last=False)
            print(f"Evicted: {evicted_key}")

        self.cache[key] = value

    def __repr__(self):
        return f"LRUCache({list(self.cache.keys())})"

# Alternative: Manual implementation with doubly-linked list + hash map
class Node:
    """Doubly-linked-list node pairing a cache key with its value."""

    def __init__(self, key, value):
        self.key = key
        self.value = value
        # Neighbor links are wired up later by the owning list.
        self.prev = self.next = None

class LRUCacheManual:
    """
    LRU Cache from first principles: hash map + doubly-linked list.

    The map gives O(1) key lookup; the list keeps entries in recency order
    between two sentinel nodes (head side = most recent, tail side = least).
    This is the implementation you'd explain in an interview.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = {}  # key -> Node

        # Sentinel head/tail remove None checks when splicing at either end.
        self.head = Node(0, 0)
        self.tail = Node(0, 0)
        self.head.next = self.tail
        self.tail.prev = self.head

    def get(self, key: str) -> Optional[Any]:
        """Return the value for key (marking it most recent), or None if absent."""
        node = self.cache.get(key)
        if node is None:
            return None

        # Re-splice at the front = most recently used.
        self._remove(node)
        self._add_to_front(node)
        return node.value

    def put(self, key: str, value: Any) -> None:
        """Insert or overwrite key; evict the LRU entry when over capacity."""
        existing = self.cache.get(key)
        if existing is not None:
            self._remove(existing)  # old node is replaced below

        fresh = Node(key, value)
        self._add_to_front(fresh)
        self.cache[key] = fresh

        if len(self.cache) > self.capacity:
            victim = self.tail.prev  # least recently used lives before the tail sentinel
            self._remove(victim)
            del self.cache[victim.key]

    def _remove(self, node: Node):
        """Splice `node` out of the linked list."""
        before, after = node.prev, node.next
        before.next = after
        after.prev = before

    def _add_to_front(self, node: Node):
        """Splice `node` in directly after the head sentinel (most recent slot)."""
        first = self.head.next
        node.prev = self.head
        node.next = first
        first.prev = node
        self.head.next = node

# Example usage (capacity 3: fourth insert triggers an eviction)
cache = LRUCache(3)

cache.put("user:1", {"name": "Alice", "age": 30})
cache.put("user:2", {"name": "Bob", "age": 25})
cache.put("user:3", {"name": "Charlie", "age": 35})
print(cache)  # LRUCache(['user:1', 'user:2', 'user:3'])

cache.get("user:1")  # Access user:1 (now most recently used)
print(cache)  # LRUCache(['user:2', 'user:3', 'user:1'])

cache.put("user:4", {"name": "David", "age": 28})  # Evicts user:2 (LRU); prints "Evicted: user:2"
print(cache)  # LRUCache(['user:3', 'user:1', 'user:4'])

Consistent Hashing

Use Case: Distributed caching (Memcached, Redis Cluster), load balancing, CDN
import hashlib
import bisect
from typing import List, Optional

class ConsistentHash:
    """
    Consistent Hashing for distributed systems.

    Problem: with modulo hashing (key % N), adding or removing a server
    rehashes almost every key - a massive cache invalidation.

    Solution: place both servers and keys on a ring [0, 2^32); each key is
    owned by the next server clockwise. Adding or removing a server then
    only moves ~1/N of the keys.

    Virtual nodes: each physical server occupies many ring positions so
    load spreads evenly across servers.
    """

    def __init__(self, nodes: List[str] = None, virtual_nodes: int = 150):
        """
        nodes: List of server names
        virtual_nodes: ring positions per physical server
                      (higher = smoother distribution, more memory)
        """
        self.virtual_nodes = virtual_nodes
        self.ring = []      # sorted ring positions (hash values)
        self.ring_map = {}  # ring position -> physical node name
        self.nodes = set()

        for node in (nodes or []):
            self.add_node(node)

    def _hash(self, key: str) -> int:
        """Map a string to a ring position in [0, 2^32). MD5 gives spread, not security."""
        digest = hashlib.md5(key.encode()).hexdigest()
        return int(digest, 16) % (2**32)

    def add_node(self, node: str):
        """Place a server on the ring as `virtual_nodes` distinct positions."""
        self.nodes.add(node)

        for replica in range(self.virtual_nodes):
            position = self._hash(f"{node}:{replica}")
            self.ring.append(position)
            self.ring_map[position] = node

        # Keep the ring sorted so lookups can binary-search.
        self.ring.sort()

    def remove_node(self, node: str):
        """Take all of a server's virtual positions off the ring."""
        self.nodes.discard(node)

        for replica in range(self.virtual_nodes):
            position = self._hash(f"{node}:{replica}")
            if position in self.ring_map:
                del self.ring[bisect.bisect_left(self.ring, position)]
                del self.ring_map[position]

    def get_node(self, key: str) -> Optional[str]:
        """Return the server owning `key`: the first ring position clockwise of its hash."""
        if not self.ring:
            return None

        position = self._hash(key)
        index = bisect.bisect_right(self.ring, position)

        if index == len(self.ring):
            index = 0  # wrapped past the top of the ring

        return self.ring_map[self.ring[index]]

# Example: Distributed cache with 3 servers
ch = ConsistentHash(["server1", "server2", "server3"])

# Simulate cache key distribution
keys = [f"user:{i}" for i in range(1, 21)]

print("=== Initial Distribution ===")
distribution = {}
for key in keys:
    server = ch.get_node(key)
    distribution[server] = distribution.get(server, 0) + 1
    print(f"{key} -> {server}")

print(f"\nDistribution: {distribution}")

# Record each key's owner BEFORE scaling out so moves can be counted.
# (The original compared against an undefined `ch2`, which made every key
# appear to move and always reported 100%.)
old_owner = {key: ch.get_node(key) for key in keys}

# Add a new server - see how many keys move
print("\n=== Adding server4 ===")
ch.add_node("server4")

moved = 0
new_distribution = {}
for key in keys:
    new_server = ch.get_node(key)
    new_distribution[new_server] = new_distribution.get(new_server, 0) + 1
    if new_server != old_owner[key]:
        moved += 1

print(f"New distribution: {new_distribution}")
print(f"Keys moved: ~{moved} out of {len(keys)} (~{moved/len(keys)*100:.0f}%)")
print(f"Expected: ~{100/4:.0f}% with 4 servers")
← Back to Study Guide

Continue to Part 2

This guide continues with:

  • Bloom Filters & Count-Min Sketch
  • URL Shortener Implementation
  • Rate Limiter (Token Bucket, Sliding Window)
  • Fan-out Service (News Feed)
  • Autocomplete/Trie
  • Distributed Systems Patterns

Note: Due to length, create a Part 2 HTML file for remaining implementations.