From f4a354f5176ba0942c5c17204c8e7e422e71258b Mon Sep 17 00:00:00 2001 From: samuelarogbonlo Date: Fri, 10 Jan 2025 13:10:25 +0100 Subject: [PATCH] Add GitHub Actions workflows for CI/CD and also did clinting cleanups for the whole codebase --- .github/workflows/ci.yml | 4 +- .github/workflows/publish.yml | 1 - pyproject.toml | 28 +-- setup.cfg | 48 +++++ setup.py | 9 +- src/__init__.py | 0 src/knetvis/__init__.py | 14 ++ src/knetvis/cli.py | 73 +++++++ src/knetvis/models.py | 26 +++ src/{ => knetvis}/policy.py | 111 ++++++---- src/knetvis/simulator.py | 223 ++++++++++++++++++++ src/knetvis/visualizer.py | 387 ++++++++++++++++++++++++++++++++++ src/main.py | 110 ---------- src/simulator.py | 176 ---------------- src/visualizer.py | 356 ------------------------------- tests/test_policy.py | 28 ++- tests/test_simulator.py | 28 ++- tests/test_visualizer.py | 45 ++-- 18 files changed, 911 insertions(+), 756 deletions(-) create mode 100644 setup.cfg delete mode 100644 src/__init__.py create mode 100644 src/knetvis/__init__.py create mode 100644 src/knetvis/cli.py create mode 100644 src/knetvis/models.py rename src/{ => knetvis}/policy.py (54%) create mode 100644 src/knetvis/simulator.py create mode 100644 src/knetvis/visualizer.py delete mode 100644 src/main.py delete mode 100644 src/simulator.py delete mode 100644 src/visualizer.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 25c8978..08a01d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,3 @@ -# .github/workflows/ci.yml name: CI on: @@ -66,5 +65,4 @@ jobs: uses: actions/upload-artifact@v2 with: name: dist - path: dist/ - + path: dist/ \ No newline at end of file diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 9e69a09..c608b7a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,4 +1,3 @@ -# .github/workflows/publish.yml name: Publish to PyPI on: diff --git a/pyproject.toml b/pyproject.toml index 2dda6f9..b1d25cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,16 @@ [build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" [project] name = "knetvis" version = "0.1.0" -description = "A CLI tool for network visualization" -readme = "README.md" -requires-python = ">=3.7" authors = [ { name = "Samuel Arogbonlo", email = "sbayo971@gmail.com" }, ] +description = "A CLI tool for Kubernetes Network Policy visualization" +readme = "README.md" +requires-python = ">=3.8" dependencies = [ "kubernetes>=28.1.0", "networkx>=3.1", @@ -21,19 +21,7 @@ dependencies = [ ] [project.scripts] -knetvis = "knetvis.cli:main" - -[tool.hatch.build.targets.wheel] -packages = ["src"] +knetvis = "knetvis.cli:cli" -[project.optional-dependencies] -dev = [ - "pytest>=7.3.1", - "pytest-cov>=4.1.0", - "black>=22.3.0", - "isort>=5.10.1", - "flake8>=4.0.1", - "pre-commit>=2.20.0", - "tox>=3.24.0", - "mypy>=0.982", -] \ No newline at end of file +[tool.setuptools.packages.find] +where = ["src"] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..0341f16 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,48 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203 +per-file-ignores = + __init__.py:F401 +exclude = .git,__pycache__,build,dist + +[mypy] +python_version = 3.8 +warn_return_any = True +warn_unused_configs = True +disallow_untyped_defs = True + +[mypy-kubernetes.*] +ignore_missing_imports = True + +[mypy-networkx.*] +ignore_missing_imports = True + +[mypy-yaml.*] +ignore_missing_imports = True + +[tool:pytest] +testpaths = tests +python_files = test_*.py +python_functions = test_* +addopts = --cov=src --cov-report=term-missing + +[coverage:run] +source = src + +[coverage:report] +exclude_lines = + pragma: no cover + def __repr__ + if self.debug: + raise NotImplementedError + if __name__ == .__main__.: + pass + raise ImportError + +[isort] +profile = black +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +line_length = 88 \ No newline at end of file diff --git a/setup.py b/setup.py index 557b934..ad5560e 100644 --- a/setup.py +++ b/setup.py @@ -6,13 +6,14 @@ setup( name="knetvis", version="0.1.0", - author="Your Name", - author_email="your.email@example.com", + author="Samuel Arogbonlo", + author_email="sbayo971@gmail.com", description="Kubernetes Network Policy Visualization Tool", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/yourusername/knetvis", - packages=find_packages(), + package_dir={"": "src"}, # Add this line - tells setuptools packages are in src directory + packages=find_packages(where="src"), # Update this line - look for packages in src directory classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -35,7 +36,7 @@ ], entry_points={ "console_scripts": [ - "knetvis=src.main:cli", + "knetvis=knetvis.cli:cli", ], }, ) \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/knetvis/__init__.py b/src/knetvis/__init__.py new file mode 100644 index 0000000..c3132e4 --- /dev/null +++ b/src/knetvis/__init__.py @@ -0,0 +1,14 @@ +# src/knetvis/__init__.py + +from .policy import PolicyParser +from .simulator import TrafficSimulator +from .visualizer import NetworkVisualizer + +__version__ = "0.1.0" + +# Export these classes as main package interfaces +__all__ = [ + "PolicyParser", + "TrafficSimulator", + "NetworkVisualizer", +] diff --git a/src/knetvis/cli.py b/src/knetvis/cli.py new file mode 100644 index 0000000..c056c3b --- /dev/null +++ b/src/knetvis/cli.py @@ -0,0 +1,73 @@ +import os + +import click +from rich.console import Console + +from .models import Target +from .policy import PolicyParser +from .simulator import TrafficSimulator + +console = Console() + + +@click.group() +def cli() -> None: + """knetvis - Kubernetes Network Policy Visualization Tool""" + pass + + +@cli.command() +@click.argument("source") +@click.argument("destination") +def test(source: str, destination: str) -> None: + """Test connectivity between resources.""" + try: + source_target = Target.from_str(source) + dest_target = Target.from_str(destination) + parser = PolicyParser() + simulator = TrafficSimulator(parser) + + if not simulator.check_resource_exists(source_target): + console.print(f"[red]Error: Source resource {source} not found[/red]") + return + if not simulator.check_resource_exists(dest_target): + console.print( + f"[red]Error: Destination resource {destination} not found[/red]" + ) + return + + allowed = simulator.test_connectivity(source_target, dest_target) + + if allowed: + console.print("[green]✓ Traffic is allowed[/green]") + else: + console.print("[red]✗ Traffic is blocked[/red]") + + except Exception as e: + console.print(f"[red]Error: {str(e)}[/red]") + + +@cli.command() +@click.argument("policy-file") +def validate(policy_file: str) -> None: + """Validate a network policy file""" + try: + if not os.path.exists(policy_file): + console.print(f"[red]Error: File '{policy_file}' does not exist[/red]") + return + + parser = PolicyParser() + is_valid, message = parser.validate_policy(policy_file) + + if is_valid: + console.print("[green]✓ Policy is valid[/green]") + else: + console.print("[yellow]Policy has potential issues:[/yellow]") + console.print(f"[yellow]{message}[/yellow]") + + except Exception as e: + console.print(f"[red]Error: {str(e)}[/red]") + + +if __name__ == "__main__": + cli() diff --git a/src/knetvis/models.py b/src/knetvis/models.py new file mode 100644 index 0000000..7b02b1a --- /dev/null +++ b/src/knetvis/models.py @@ -0,0 +1,26 @@ +import re +from dataclasses import dataclass +from typing import ClassVar, Pattern + + +@dataclass +class Target: + namespace: str + kind: str + name: str + TARGET_PATTERN: ClassVar[Pattern] = re.compile(r"^(?:([^/]+)/)?([^/]+)/([^/]+)$") + + @classmethod + def from_str(cls, target_str: str) -> "Target": + match = cls.TARGET_PATTERN.match(target_str) + if not match: + raise ValueError( + f"Invalid target format: {target_str}. " + "Expected format: [namespace/]kind/name" + ) + namespace, kind, name = match.groups() + return cls(namespace=namespace or "default", kind=kind, name=name) + + def __str__(self) -> str: + """Return string representation in format namespace/kind/name""" + return f"{self.namespace}/{self.kind}/{self.name}" diff --git a/src/policy.py b/src/knetvis/policy.py similarity index 54% rename from src/policy.py rename to src/knetvis/policy.py index 176c5f6..a0322ed 100644 --- a/src/policy.py +++ b/src/knetvis/policy.py @@ -1,22 +1,21 @@ -from typing import List, Dict, Optional, Tuple -from kubernetes import client, config +from typing import List, Tuple + import yaml +from kubernetes import client, config -class PolicyParser: - """Handles parsing and validation of Kubernetes NetworkPolicies""" - def __init__(self): +class PolicyParser: + def __init__(self) -> None: # Load kubernetes configuration try: config.load_kube_config() - except: + except Exception: config.load_incluster_config() self.api = client.NetworkingV1Api() def load_policy_file(self, filename: str) -> List[dict]: - """Load NetworkPolicies from a YAML file, supporting multiple documents""" - with open(filename, 'r') as f: + with open(filename, "r") as f: return list(yaml.safe_load_all(f)) def get_namespace_policies(self, namespace: str) -> List[dict]: @@ -37,17 +36,19 @@ def validate_policy(self, policy_file: str) -> Tuple[bool, str]: if not doc: # Skip empty documents continue - if doc.get('kind') == 'NetworkPolicy': + if doc.get("kind") == "NetworkPolicy": # Validate NetworkPolicy issues = self._validate_network_policy(doc) all_issues.extend(issues) - elif doc.get('kind') == 'Namespace': + elif doc.get("kind") == "Namespace": # Basic namespace validation - if 'metadata' not in doc or 'name' not in doc['metadata']: - all_issues.append("Namespace missing required metadata.name field") - elif doc.get('kind') == 'Pod': + if "metadata" not in doc or "name" not in doc["metadata"]: + all_issues.append( + "Namespace missing required metadata.name field" + ) + elif doc.get("kind") == "Pod": # Basic pod validation - if 'metadata' not in doc or 'name' not in doc['metadata']: + if "metadata" not in doc or "name" not in doc["metadata"]: all_issues.append("Pod missing required metadata.name field") if all_issues: @@ -64,29 +65,29 @@ def _validate_network_policy(self, policy: dict) -> List[str]: issues = [] # Check required fields - required_fields = ['apiVersion', 'kind', 'metadata', 'spec'] + required_fields = ["apiVersion", "kind", "metadata", "spec"] for field in required_fields: if field not in policy: issues.append(f"Missing required field: {field}") if not issues: # Only continue if basic structure is valid - spec = policy.get('spec', {}) + spec = policy.get("spec", {}) if not spec: issues.append("Empty spec in NetworkPolicy") return issues # Check podSelector exists - if 'podSelector' not in spec: + if "podSelector" not in spec: issues.append("Missing podSelector in spec") # Validate ingress rules - if 'ingress' in spec: - ingress_issues = self._validate_ingress_rules(spec['ingress']) + if "ingress" in spec: + ingress_issues = self._validate_ingress_rules(spec["ingress"]) issues.extend(ingress_issues) # Validate egress rules - if 'egress' in spec: - egress_issues = self._validate_egress_rules(spec['egress']) + if "egress" in spec: + egress_issues = self._validate_egress_rules(spec["egress"]) issues.extend(egress_issues) return issues @@ -97,17 +98,32 @@ def _validate_ingress_rules(self, rules: List[dict]) -> List[str]: for i, rule in enumerate(rules, 1): # Validate ports - if 'ports' in rule: - for port in rule.get('ports', []): - if 'port' not in port: - issues.append(f"Ingress rule {i}: Port specification missing port number") - if 'protocol' in port and port['protocol'] not in ['TCP', 'UDP', 'SCTP']: - issues.append(f"Ingress rule {i}: Invalid protocol {port['protocol']}") + if "ports" in rule: + for port in rule.get("ports", []): + if "port" not in port: + issues.append( + f"Ingress rule {i}: Port specification missing port number" + ) + if "protocol" in port and port["protocol"] not in [ + "TCP", + "UDP", + "SCTP", + ]: + issues.append( + f"Ingress rule {i}: Invalid protocol {port['protocol']}" + ) # Validate from section - if 'from' in rule: - for peer in rule['from']: - if not any(k in peer for k in ['podSelector', 'namespaceSelector', 'ipBlock']): + if "from" in rule: + for peer in rule["from"]: + if not any( + k in peer + for k in [ + "podSelector", + "namespaceSelector", + "ipBlock", + ] + ): issues.append(f"Ingress rule {i}: Peer missing selector") return issues @@ -118,17 +134,32 @@ def _validate_egress_rules(self, rules: List[dict]) -> List[str]: for i, rule in enumerate(rules, 1): # Validate ports - if 'ports' in rule: - for port in rule.get('ports', []): - if 'port' not in port: - issues.append(f"Egress rule {i}: Port specification missing port number") - if 'protocol' in port and port['protocol'] not in ['TCP', 'UDP', 'SCTP']: - issues.append(f"Egress rule {i}: Invalid protocol {port['protocol']}") + if "ports" in rule: + for port in rule.get("ports", []): + if "port" not in port: + issues.append( + f"Egress rule {i}: Port specification missing port number" + ) + if "protocol" in port and port["protocol"] not in [ + "TCP", + "UDP", + "SCTP", + ]: + issues.append( + f"Egress rule {i}: Invalid protocol {port['protocol']}" + ) # Validate to section - if 'to' in rule: - for peer in rule['to']: - if not any(k in peer for k in ['podSelector', 'namespaceSelector', 'ipBlock']): + if "to" in rule: + for peer in rule["to"]: + if not any( + k in peer + for k in [ + "podSelector", + "namespaceSelector", + "ipBlock", + ] + ): issues.append(f"Egress rule {i}: Peer missing selector") - return issues \ No newline at end of file + return issues diff --git a/src/knetvis/simulator.py b/src/knetvis/simulator.py new file mode 100644 index 0000000..32d91c8 --- /dev/null +++ b/src/knetvis/simulator.py @@ -0,0 +1,223 @@ +from typing import List + +from kubernetes import client + +from .models import Target +from .policy import PolicyParser + + +class TrafficSimulator: + def __init__(self, policy_parser: PolicyParser) -> None: + self.policy_parser = policy_parser + self.core_api = client.CoreV1Api() + + def check_resource_exists(self, target: "Target") -> bool: + """Check if a pod exists in the specified namespace""" + try: + self.core_api.read_namespaced_pod(target.name, target.namespace) + return True + except client.exceptions.ApiException as e: + if e.status == 404: + return False + raise e + + def test_connectivity(self, source: "Target", dest: "Target") -> bool: + try: + source_policies = self.policy_parser.get_namespace_policies( + source.namespace + ) + dest_policies = self.policy_parser.get_namespace_policies(dest.namespace) + + # If no policies affect either pod, traffic is allowed + source_affected = self._policies_affect_pod(source_policies, source) + dest_affected = self._policies_affect_pod(dest_policies, dest) + + if not source_affected and not dest_affected: + return True + + egress_allowed = self._check_egress_policies(source, dest, source_policies) + ingress_allowed = self._check_ingress_policies(source, dest, dest_policies) + + print(f"Source affected: {source_affected}") + print(f"Dest affected: {dest_affected}") + print(f"Egress allowed: {egress_allowed}") + print(f"Ingress allowed: {ingress_allowed}") + + return egress_allowed and ingress_allowed + + except Exception as e: + raise Exception(f"Failed to test connectivity: {str(e)}") + + def _policies_affect_pod(self, policies: List[dict], target: "Target") -> bool: + for policy in policies: + spec = policy.get("spec", {}) + pod_selector = spec.get("pod_selector", {}) or spec.get("podSelector", {}) + if self._matches_selector(target, pod_selector): + return True + return False + + def _check_egress_policies( + self, source: "Target", dest: "Target", policies: List[dict] + ) -> bool: + matching_policies = [ + p + for p in policies + if self._matches_selector( + source, + p["spec"].get("pod_selector", {}) or p["spec"].get("podSelector", {}), + ) + ] + + if not matching_policies: + return True + + for policy in matching_policies: + if self._policy_allows_egress(policy, source, dest): + return True + return False + + def _check_ingress_policies( + self, source: "Target", dest: "Target", policies: List[dict] + ) -> bool: + matching_policies = [ + p + for p in policies + if self._matches_selector( + dest, + p["spec"].get("pod_selector", {}) or p["spec"].get("podSelector", {}), + ) + ] + + if not matching_policies: + return True + + for policy in matching_policies: + if self._policy_allows_ingress(policy, source, dest): + return True + return False + + def _matches_selector(self, target: "Target", selector: dict) -> bool: + try: + obj = self.core_api.read_namespaced_pod(target.name, target.namespace) + pod_labels = obj.metadata.labels or {} + + match_labels = selector.get("match_labels", {}) or selector.get( + "matchLabels", {} + ) + match_expressions = selector.get("match_expressions", []) or selector.get( + "matchExpressions", [] + ) + + if match_labels and not all( + pod_labels.get(k) == v for k, v in match_labels.items() + ): + return False + + for expr in match_expressions: + key = expr["key"] + operator = expr["operator"] + values = expr.get("values", []) + + if operator == "In" and pod_labels.get(key) not in values: + return False + elif operator == "NotIn" and pod_labels.get(key) in values: + return False + elif operator == "Exists" and key not in pod_labels: + return False + elif operator == "DoesNotExist" and key in pod_labels: + return False + + return True + except Exception as e: + print(f"Error matching selector: {e}") + return False + + def _policy_allows_egress( + self, policy: dict, source: "Target", dest: "Target" + ) -> bool: + spec = policy.get("spec", {}) + if "egress" not in spec and "policyTypes" not in spec: + return True + + egress_rules = spec.get("egress", []) + if not egress_rules: + return "Egress" not in spec.get("policyTypes", []) + + for rule in egress_rules: + if self._egress_rule_matches(rule, dest): + return True + return False + + def _policy_allows_ingress( + self, policy: dict, source: "Target", dest: "Target" + ) -> bool: + spec = policy.get("spec", {}) + if "ingress" not in spec and "policyTypes" not in spec: + return True + + ingress_rules = spec.get("ingress", []) + if not ingress_rules: + return "Ingress" not in spec.get("policyTypes", []) + + for rule in ingress_rules: + if self._ingress_rule_matches(rule, source): + return True + return False + + def _egress_rule_matches(self, rule: dict, dest: "Target") -> bool: + if not rule.get("to", []): + return True + + for to_peer in rule.get("to", []): + pod_selector = to_peer.get("pod_selector", {}) or to_peer.get( + "podSelector", {} + ) + if pod_selector and self._matches_selector(dest, pod_selector): + return True + + namespace_selector = to_peer.get("namespace_selector", {}) or to_peer.get( + "namespaceSelector", {} + ) + if namespace_selector: + try: + ns = self.core_api.read_namespace(dest.namespace) + ns_labels = ns.metadata.labels or {} + match_labels = namespace_selector.get( + "match_labels", {} + ) or namespace_selector.get("matchLabels", {}) + if all(ns_labels.get(k) == v for k, v in match_labels.items()): + return True + except client.exceptions.ApiException as e: + print(f"Error checking namespace: {e}") + return False + return False + + def _ingress_rule_matches(self, rule: dict, source: "Target") -> bool: + # Handle both 'from' and '_from' keys + from_peers = rule.get("from", []) or rule.get("_from", []) + if not from_peers: + return True + + for from_peer in from_peers: + pod_selector = from_peer.get("pod_selector", {}) or from_peer.get( + "podSelector", {} + ) + if pod_selector and self._matches_selector(source, pod_selector): + return True + + namespace_selector = from_peer.get( + "namespace_selector", {} + ) or from_peer.get("namespaceSelector", {}) + if namespace_selector: + try: + ns = self.core_api.read_namespace(source.namespace) + ns_labels = ns.metadata.labels or {} + match_labels = namespace_selector.get( + "match_labels", {} + ) or namespace_selector.get("matchLabels", {}) + if all(ns_labels.get(k) == v for k, v in match_labels.items()): + return True + except client.exceptions.ApiException as e: + print(f"Error checking namespace: {e}") + return False + return False diff --git a/src/knetvis/visualizer.py b/src/knetvis/visualizer.py new file mode 100644 index 0000000..91ae6ac --- /dev/null +++ b/src/knetvis/visualizer.py @@ -0,0 +1,387 @@ +# src/visualzer.py +from dataclasses import dataclass +from typing import Dict, List, Set + +import matplotlib.pyplot as plt +import networkx as nx +from kubernetes import client +from rich.console import Console + +console = Console() + + +@dataclass(frozen=True) +class NetworkNode: + name: str + kind: str + namespace: str + labels: Dict[str, str] + + def __hash__(self) -> int: + return hash((self.name, self.namespace)) + + +class NetworkVisualizer: + def __init__(self) -> None: + self.graph = nx.DiGraph() + self.core_api = client.CoreV1Api() + self.colors: Dict[str, str] = { + "pod": "#4299E1", + "namespace": "#48BB78", + "ipblock": "#F6AD55", + "allow": "#48BB78", + "deny": "#F56565", + } + + def create_graph(self, namespace: str, policies: List[dict]) -> None: + self.namespace = namespace + self.graph.clear() + self._add_namespace_pods(namespace) + for policy in policies: + self._add_policy_to_graph(policy) + + nodes_count = self.graph.number_of_nodes() + edges_count = self.graph.number_of_edges() + console.print( + f"[green]Created graph with {nodes_count} nodes " + f"and {edges_count} edges[/green]" + ) + + def save_graph(self, output_file: str) -> None: + plt.figure(figsize=(12, 8)) + pos = nx.spring_layout(self.graph, k=1, iterations=50) + self._draw_nodes(pos) + self._draw_edges(pos) + self._add_labels(pos) + plt.title("Network Policy Visualization") + plt.axis("off") + plt.tight_layout() + plt.savefig(output_file, dpi=300, bbox_inches="tight") + plt.close() + + console.print(f"[green]Network visualization saved to {output_file}[/green]") + + def _add_namespace_pods(self, namespace: str) -> None: + """Add all pods in the namespace to the graph""" + try: + pods = self.core_api.list_namespaced_pod(namespace) + console.print("\nPod Label Information:") + for pod in pods.items: + console.print( + f"Pod: {pod.metadata.name}, " f"Labels: {pod.metadata.labels}" + ) + node = NetworkNode( + name=pod.metadata.name, + kind="pod", + namespace=namespace, + labels=pod.metadata.labels or {}, + ) + self._add_node(node) + except Exception as e: + message = f"[yellow]Warning: Failed to fetch pods: {str(e)}[/yellow]" + console.print(message) + + def _add_policy_to_graph(self, policy: dict) -> None: + """Process a network policy and add its rules to the graph""" + spec = policy.get("spec", {}) + + # Get pods selected by this policy + pod_selector = spec.get("pod_selector") or spec.get("podSelector", {}) + selected_pods = self._get_selected_pods(self.namespace, pod_selector) + + pod_names = [pod.name for pod in selected_pods] + console.print(f"Selected pods: {pod_names}") + + # Process ingress rules if they exist + ingress_rules = spec.get("ingress", []) + if ingress_rules is not None: # Check if ingress rules are defined + console.print("Processing ingress rules") + for rule in ingress_rules: + self._process_ingress_rule(rule, selected_pods) + + # Process egress rules if they exist + egress_rules = spec.get("egress", []) + if egress_rules is not None: # Check if egress rules are defined + console.print("Processing egress rules") + for rule in egress_rules: + self._process_egress_rule(rule, selected_pods) + + def _get_selected_pods(self, namespace: str, selector: dict) -> Set[NetworkNode]: + """Get pods that match a label selector""" + try: + label_selector = self._build_label_selector(selector) + pods = self.core_api.list_namespaced_pod( + namespace, label_selector=label_selector + ) + + selected = { + NetworkNode( + name=pod.metadata.name, + kind="pod", + namespace=namespace, + labels=pod.metadata.labels or {}, + ) + for pod in pods.items + } + + pod_names = [pod.name for pod in selected] + console.print(f"Found matching pods: {pod_names}") + return selected + + except Exception as e: + msg = f"[yellow]Warning: Failed to get selected pods: " f"{str(e)}[/yellow]" + console.print(msg) + return set() + + def _process_ingress_rule(self, rule: dict, target_pods: Set[NetworkNode]) -> None: + from_peers = rule.get("from", []) or rule.get("_from", []) + if not from_peers: + return + + for from_peer in from_peers: + ns_selector = from_peer.get("namespace_selector") or from_peer.get( + "namespaceSelector" + ) + pod_selector = from_peer.get("pod_selector") or from_peer.get("podSelector") + + try: + # When we have both selectors in same peer (AND condition) + if ns_selector and pod_selector: + self._handle_dual_selector(ns_selector, pod_selector, target_pods) + # Handle single namespace selector + elif ns_selector: + self._handle_namespace_selector(ns_selector, target_pods) + # Handle single pod selector + elif pod_selector: + self._handle_pod_selector(pod_selector, self.namespace, target_pods) + + except Exception as e: + msg = ( + f"[yellow]Warning: Error processing selectors: " + f"{str(e)}[/yellow]" + ) + console.print(msg) + + def _handle_dual_selector( + self, + ns_selector: dict, + pod_selector: dict, + target_pods: Set[NetworkNode], + ) -> None: + """Handle both namespace and pod selectors""" + ns_label_selector = self._build_label_selector(ns_selector) + namespaces = self.core_api.list_namespace(label_selector=ns_label_selector) + + ns_names = [ns.metadata.name for ns in namespaces.items] + console.print(f"Found namespaces matching selector: {ns_names}") + + for ns in namespaces.items: + pod_label_selector = self._build_label_selector(pod_selector) + pods = self.core_api.list_namespaced_pod( + ns.metadata.name, label_selector=pod_label_selector + ) + console.print(f"Checking pods in namespace {ns.metadata.name}") + + for pod in pods.items: + source = NetworkNode( + name=pod.metadata.name, + kind="pod", + namespace=ns.metadata.name, + labels=pod.metadata.labels or {}, + ) + self._add_node(source) + for target in target_pods: + console.print( + f"Adding edge: {source.namespace}/{source.name} " + f"-> {target.namespace}/{target.name}" + ) + self._add_edge(source, target, "allow") + + def _handle_namespace_selector( + self, ns_selector: dict, target_pods: Set[NetworkNode] + ) -> None: + """Handle namespace selector only""" + label_selector = self._build_label_selector(ns_selector) + namespaces = self.core_api.list_namespace(label_selector=label_selector) + for ns in namespaces.items: + source = NetworkNode( + name=ns.metadata.name, + kind="namespace", + namespace="", + labels=ns.metadata.labels or {}, + ) + self._add_node(source) + for target in target_pods: + console.print(f"Adding namespace edge: {source.name} -> {target.name}") + self._add_edge(source, target, "allow") + + def _handle_pod_selector( + self, + pod_selector: dict, + namespace: str, + target_pods: Set[NetworkNode], + ) -> None: + """Handle pod selector only""" + source_pods = self._get_selected_pods(namespace, pod_selector) + for source in source_pods: + for target in target_pods: + console.print(f"Adding edge: {source.name} -> {target.name}") + self._add_node(source) + self._add_edge(source, target, "allow") + + def _process_egress_rule(self, rule: dict, source_pods: Set[NetworkNode]) -> None: + """Process an egress rule and add relevant edges""" + pod_names = [pod.name for pod in source_pods] + console.print(f"Processing egress rule for sources: {pod_names}") + + for to_peer in rule.get("to", []): + target_pods = self._get_pods_from_peer(to_peer) + target_names = [pod.name for pod in target_pods] + console.print(f"Found target pods: {target_names}") + + for source in source_pods: + for target in target_pods: + console.print(f"Adding edge: {source.name} -> {target.name}") + self._add_node(target) + self._add_edge(source, target, "allow") + + def _get_pods_from_peer(self, peer: dict) -> Set[NetworkNode]: + """Get pods that match both namespace and pod selectors""" + pods = set() + + pod_selector = peer.get("pod_selector") or peer.get("podSelector") + ns_selector = peer.get("namespace_selector") or peer.get("namespaceSelector") + + try: + if ns_selector and pod_selector: + pods = self._get_pods_with_dual_selector(ns_selector, pod_selector) + elif ns_selector: + pods = self._get_pods_with_ns_selector(ns_selector) + elif pod_selector: + pods = self._get_selected_pods(self.namespace, pod_selector) + + except Exception as e: + msg = ( + f"[yellow]Warning: Error getting pods from peer: " f"{str(e)}[/yellow]" + ) + console.print(msg) + + return pods + + def _get_pods_with_dual_selector( + self, ns_selector: dict, pod_selector: dict + ) -> Set[NetworkNode]: + """Get pods matching both namespace and pod selectors""" + pods = set() + ns_label_selector = self._build_label_selector(ns_selector) + namespaces = self.core_api.list_namespace(label_selector=ns_label_selector) + + pod_label_selector = self._build_label_selector(pod_selector) + for ns in namespaces.items: + ns_pods = self.core_api.list_namespaced_pod( + ns.metadata.name, label_selector=pod_label_selector + ) + for pod in ns_pods.items: + pods.add( + NetworkNode( + name=pod.metadata.name, + kind="pod", + namespace=ns.metadata.name, + labels=pod.metadata.labels or {}, + ) + ) + pod_names = [p.name for p in pods] + console.print(f"Found pods in namespace {ns.metadata.name}: {pod_names}") + + return pods + + def _get_pods_with_ns_selector(self, ns_selector: dict) -> Set[NetworkNode]: + """Get pods using namespace selector only""" + pods = set() + label_selector = self._build_label_selector(ns_selector) + namespaces = self.core_api.list_namespace(label_selector=label_selector) + for ns in namespaces.items: + pods.add( + NetworkNode( + name=ns.metadata.name, + kind="namespace", + namespace="", + labels=ns.metadata.labels or {}, + ) + ) + return pods + + def _build_label_selector(self, selector: dict) -> str: + """Build a label selector string from a selector dict""" + if not selector: + return "" + + parts = [] + match_labels = selector.get("match_labels") or selector.get("matchLabels", {}) + for key, value in match_labels.items(): + parts.append(f"{key}={value}") + + return ",".join(parts) + + def _add_node(self, node: NetworkNode) -> None: + """Add a node to the graph if it doesn't exist""" + node_id = f"{node.namespace}/{node.name}" + if node_id not in self.graph: + self.graph.add_node( + node_id, + kind=node.kind, + namespace=node.namespace, + labels=node.labels, + ) + + def _add_edge( + self, source: NetworkNode, target: NetworkNode, policy_type: str + ) -> None: + """Add an edge between nodes""" + source_id = f"{source.namespace}/{source.name}" + target_id = f"{target.namespace}/{target.name}" + self.graph.add_edge(source_id, target_id, type=policy_type) + + def _draw_nodes(self, pos: dict) -> None: + """Draw nodes with different colors based on type""" + for kind in ["pod", "namespace", "ipblock"]: + nodes = [n for n, d in self.graph.nodes(data=True) if d["kind"] == kind] + if nodes: + nx.draw_networkx_nodes( + self.graph, + pos, + nodelist=nodes, + node_color=self.colors[kind], + node_size=1000, + alpha=0.8, + ) + + def _draw_edges(self, pos: dict) -> None: + """Draw edges with different colors based on policy type""" + for policy_type in ["allow", "deny"]: + edges = [ + (u, v) + for u, v, d in self.graph.edges(data=True) + if d["type"] == policy_type + ] + if edges: + nx.draw_networkx_edges( + self.graph, + pos, + edgelist=edges, + edge_color=self.colors[policy_type], + arrows=True, + arrowsize=20, + ) + + def _add_labels(self, pos: dict) -> None: + """Add labels to nodes""" + labels = {} + for node in self.graph.nodes(): + name = node.split("/")[-1] + kind = self.graph.nodes[node]["kind"] + labels[node] = f"{kind}\n{name}" + + nx.draw_networkx_labels( + self.graph, pos, labels, font_size=8, font_weight="bold" + ) diff --git a/src/main.py b/src/main.py deleted file mode 100644 index b7faefd..0000000 --- a/src/main.py +++ /dev/null @@ -1,110 +0,0 @@ -# src/main.py -import os -import click -from rich.console import Console -from .policy import PolicyParser -from .simulator import TrafficSimulator -from .visualizer import NetworkVisualizer - -console = Console() - -class Target: - """Represents a Kubernetes resource target""" - def __init__(self, target_str: str): - parts = target_str.split('/') - if len(parts) == 2: - self.namespace = 'default' - self.kind, self.name = parts - elif len(parts) == 3: - self.namespace, self.kind, self.name = parts - else: - raise ValueError(f"Invalid target format: {target_str}") - -@click.group() -def cli(): - """knetvis - Kubernetes Network Policy Visualization Tool""" - pass - -@cli.command() -@click.argument('namespace') -@click.option('--output', '-o', default='network.png', help='Output file for visualization') -def visualize(namespace: str, output: str): - """Visualize network policies in a namespace""" - try: - parser = PolicyParser() - visualizer = NetworkVisualizer() - - # Get policies in namespace - policies = parser.get_namespace_policies(namespace) - - # Check if namespace exists and has policies - if not policies: - console.print(f"[yellow]Warning: No network policies found in namespace '{namespace}'[/yellow]") - - # Create and save visualization - visualizer.create_graph(namespace, policies) - visualizer.save_graph(output) - - console.print(f"[green]Network visualization saved to {output}[/green]") - except client.exceptions.ApiException as e: - if e.status == 404: - console.print(f"[red]Error: Namespace '{namespace}' not found[/red]") - else: - console.print(f"[red]Error: {str(e)}[/red]") - except Exception as e: - console.print(f"[red]Error: {str(e)}[/red]") - -@cli.command() -@click.argument('source') -@click.argument('destination') -def test(source: str, destination: str): - """Test connectivity between two resources""" - try: - source_target = Target(source) - dest_target = Target(destination) - - parser = PolicyParser() - simulator = TrafficSimulator(parser) - - # First check if resources exist - if not simulator.check_resource_exists(source_target): - console.print(f"[red]Error: Source resource {source} not found[/red]") - return - if not simulator.check_resource_exists(dest_target): - console.print(f"[red]Error: Destination resource {destination} not found[/red]") - return - - allowed = simulator.test_connectivity(source_target, dest_target) - - if allowed: - console.print("[green]✓ Traffic is allowed[/green]") - else: - console.print("[red]✗ Traffic is blocked[/red]") - - except Exception as e: - console.print(f"[red]Error: {str(e)}[/red]") - -@cli.command() -@click.argument('policy-file') -def validate(policy_file: str): - """Validate a network policy file""" - try: - # First check if file exists - if not os.path.exists(policy_file): - console.print(f"[red]Error: File '{policy_file}' does not exist[/red]") - return - - parser = PolicyParser() - is_valid, message = parser.validate_policy(policy_file) - - if is_valid: - console.print("[green]✓ Policy is valid[/green]") - else: - console.print("[yellow]Policy has potential issues:[/yellow]") - console.print(f"[yellow]{message}[/yellow]") - - except Exception as e: - console.print(f"[red]Error: {str(e)}[/red]") - -if __name__ == '__main__': - cli() \ No newline at end of file diff --git a/src/simulator.py b/src/simulator.py deleted file mode 100644 index 690b2d7..0000000 --- a/src/simulator.py +++ /dev/null @@ -1,176 +0,0 @@ -# src/simulator.py -from typing import List, Dict -from kubernetes import client -from .policy import PolicyParser - -class TrafficSimulator: - def __init__(self, policy_parser: PolicyParser): - self.policy_parser = policy_parser - self.core_api = client.CoreV1Api() - - def check_resource_exists(self, target: 'Target') -> bool: - """Check if a pod exists in the specified namespace""" - try: - self.core_api.read_namespaced_pod(target.name, target.namespace) - return True - except client.exceptions.ApiException as e: - if e.status == 404: - return False - raise e - - def test_connectivity(self, source: 'Target', dest: 'Target') -> bool: - try: - source_policies = self.policy_parser.get_namespace_policies(source.namespace) - dest_policies = self.policy_parser.get_namespace_policies(dest.namespace) - - # If no policies affect either pod, traffic is allowed - source_affected = self._policies_affect_pod(source_policies, source) - dest_affected = self._policies_affect_pod(dest_policies, dest) - - if not source_affected and not dest_affected: - return True - - egress_allowed = self._check_egress_policies(source, dest, source_policies) - ingress_allowed = self._check_ingress_policies(source, dest, dest_policies) - - print(f"Source affected: {source_affected}") - print(f"Dest affected: {dest_affected}") - print(f"Egress allowed: {egress_allowed}") - print(f"Ingress allowed: {ingress_allowed}") - - return egress_allowed and ingress_allowed - - except Exception as e: - raise Exception(f"Failed to test connectivity: {str(e)}") - - def _policies_affect_pod(self, policies: List[dict], target: 'Target') -> bool: - for policy in policies: - spec = policy.get('spec', {}) - pod_selector = spec.get('pod_selector', {}) or spec.get('podSelector', {}) - if self._matches_selector(target, pod_selector): - return True - return False - - def _check_egress_policies(self, source: 'Target', dest: 'Target', policies: List[dict]) -> bool: - matching_policies = [p for p in policies if self._matches_selector(source, p['spec'].get('pod_selector', {}) or p['spec'].get('podSelector', {}))] - - if not matching_policies: - return True - - for policy in matching_policies: - if self._policy_allows_egress(policy, source, dest): - return True - return False - - def _check_ingress_policies(self, source: 'Target', dest: 'Target', policies: List[dict]) -> bool: - matching_policies = [p for p in policies if self._matches_selector(dest, p['spec'].get('pod_selector', {}) or p['spec'].get('podSelector', {}))] - - if not matching_policies: - return True - - for policy in matching_policies: - if self._policy_allows_ingress(policy, source, dest): - return True - return False - - def _matches_selector(self, target: 'Target', selector: dict) -> bool: - try: - obj = self.core_api.read_namespaced_pod(target.name, target.namespace) - pod_labels = obj.metadata.labels or {} - - match_labels = selector.get('match_labels', {}) or selector.get('matchLabels', {}) - match_expressions = selector.get('match_expressions', []) or selector.get('matchExpressions', []) - - if match_labels and not all(pod_labels.get(k) == v for k, v in match_labels.items()): - return False - - for expr in match_expressions: - key = expr['key'] - operator = expr['operator'] - values = expr.get('values', []) - - if operator == 'In' and pod_labels.get(key) not in values: - return False - elif operator == 'NotIn' and pod_labels.get(key) in values: - return False - elif operator == 'Exists' and key not in pod_labels: - return False - elif operator == 'DoesNotExist' and key in pod_labels: - return False - - return True - except Exception as e: - print(f"Error matching selector: {e}") - return False - - def _policy_allows_egress(self, policy: dict, source: 'Target', dest: 'Target') -> bool: - spec = policy.get('spec', {}) - if 'egress' not in spec and 'policyTypes' not in spec: - return True - - egress_rules = spec.get('egress', []) - if not egress_rules: - return 'Egress' not in spec.get('policyTypes', []) - - for rule in egress_rules: - if self._egress_rule_matches(rule, dest): - return True - return False - - def _policy_allows_ingress(self, policy: dict, source: 'Target', dest: 'Target') -> bool: - spec = policy.get('spec', {}) - if 'ingress' not in spec and 'policyTypes' not in spec: - return True - - ingress_rules = spec.get('ingress', []) - if not ingress_rules: - return 'Ingress' not in spec.get('policyTypes', []) - - for rule in ingress_rules: - if self._ingress_rule_matches(rule, source): - return True - return False - - def _egress_rule_matches(self, rule: dict, dest: 'Target') -> bool: - if not rule.get('to', []): - return True - - for to_peer in rule.get('to', []): - pod_selector = to_peer.get('pod_selector', {}) or to_peer.get('podSelector', {}) - if pod_selector and self._matches_selector(dest, pod_selector): - return True - - namespace_selector = to_peer.get('namespace_selector', {}) or to_peer.get('namespaceSelector', {}) - if namespace_selector: - try: - ns = self.core_api.read_namespace(dest.namespace) - ns_labels = ns.metadata.labels or {} - match_labels = namespace_selector.get('match_labels', {}) or namespace_selector.get('matchLabels', {}) - if all(ns_labels.get(k) == v for k, v in match_labels.items()): - return True - except: - pass - return False - - def _ingress_rule_matches(self, rule: dict, source: 'Target') -> bool: - # Handle both 'from' and '_from' keys - from_peers = rule.get('from', []) or rule.get('_from', []) - if not from_peers: - return True - - for from_peer in from_peers: - pod_selector = from_peer.get('pod_selector', {}) or from_peer.get('podSelector', {}) - if pod_selector and self._matches_selector(source, pod_selector): - return True - - namespace_selector = from_peer.get('namespace_selector', {}) or from_peer.get('namespaceSelector', {}) - if namespace_selector: - try: - ns = self.core_api.read_namespace(source.namespace) - ns_labels = ns.metadata.labels or {} - match_labels = namespace_selector.get('match_labels', {}) or namespace_selector.get('matchLabels', {}) - if all(ns_labels.get(k) == v for k, v in match_labels.items()): - return True - except: - pass - return False \ No newline at end of file diff --git a/src/visualizer.py b/src/visualizer.py deleted file mode 100644 index 1c06246..0000000 --- a/src/visualizer.py +++ /dev/null @@ -1,356 +0,0 @@ -# src/visualzer.py -from typing import List, Dict, Set, Optional -import networkx as nx -import matplotlib.pyplot as plt -from kubernetes import client -from dataclasses import dataclass -from rich.console import Console - -console = Console() - -@dataclass(frozen=True) -class NetworkNode: - """Represents a node in the network graph""" - name: str - kind: str # 'pod', 'namespace', 'ipblock' - namespace: str - labels: Dict[str, str] - - def __hash__(self): - # Hash based on name and namespace which should be unique - return hash((self.name, self.namespace)) - -class NetworkVisualizer: - """Creates visual representations of network policies""" - - def __init__(self): - self.graph = nx.DiGraph() - self.core_api = client.CoreV1Api() - - # Color scheme - self.colors = { - 'pod': '#4299E1', # Blue - 'namespace': '#48BB78', # Green - 'ipblock': '#F6AD55', # Orange - 'allow': '#48BB78', # Green - 'deny': '#F56565' # Red - } - - def create_graph(self, namespace: str, policies: List[dict]): - """Create a graph representation of network policies""" - self.namespace = namespace - self.graph.clear() - - # Add nodes for all pods in the namespace - self._add_namespace_pods(namespace) - - # Process each policy - for policy in policies: - self._add_policy_to_graph(policy) - - console.print(f"[green]Created graph with {self.graph.number_of_nodes()} nodes " - f"and {self.graph.number_of_edges()} edges[/green]") - - def save_graph(self, output_file: str): - """Save the network graph visualization to a file""" - plt.figure(figsize=(12, 8)) - - # Create layout - pos = nx.spring_layout(self.graph, k=1, iterations=50) - - # Draw nodes - self._draw_nodes(pos) - - # Draw edges - self._draw_edges(pos) - - # Add labels - self._add_labels(pos) - - plt.title("Network Policy Visualization") - plt.axis('off') - plt.tight_layout() - - # Save to file - plt.savefig(output_file, dpi=300, bbox_inches='tight') - plt.close() - - console.print(f"[green]Network visualization saved to {output_file}[/green]") - - def _add_namespace_pods(self, namespace: str): - """Add all pods in the namespace to the graph""" - try: - pods = self.core_api.list_namespaced_pod(namespace) - console.print("\nPod Label Information:") - for pod in pods.items: - console.print(f"Pod: {pod.metadata.name}, Labels: {pod.metadata.labels}") - node = NetworkNode( - name=pod.metadata.name, - kind='pod', - namespace=namespace, - labels=pod.metadata.labels or {} - ) - self._add_node(node) - except Exception as e: - console.print(f"[yellow]Warning: Failed to fetch pods: {str(e)}[/yellow]") - - def _add_policy_to_graph(self, policy: dict): - """Process a network policy and add its rules to the graph""" - spec = policy.get('spec', {}) - - # Get pods selected by this policy - pod_selector = spec.get('pod_selector') or spec.get('podSelector', {}) - selected_pods = self._get_selected_pods(self.namespace, pod_selector) - - console.print(f"Selected pods: {[pod.name for pod in selected_pods]}") - - # Process ingress rules if they exist - ingress_rules = spec.get('ingress', []) - if ingress_rules is not None: # Check if ingress rules are defined - console.print("Processing ingress rules") - for rule in ingress_rules: - self._process_ingress_rule(rule, selected_pods) - - # Process egress rules if they exist - egress_rules = spec.get('egress', []) - if egress_rules is not None: # Check if egress rules are defined - console.print("Processing egress rules") - for rule in egress_rules: - self._process_egress_rule(rule, selected_pods) - - def _get_selected_pods(self, namespace: str, selector: dict) -> Set[NetworkNode]: - """Get pods that match a label selector""" - try: - label_selector = self._build_label_selector(selector) - pods = self.core_api.list_namespaced_pod( - namespace, - label_selector=label_selector - ) - - selected = { - NetworkNode( - name=pod.metadata.name, - kind='pod', - namespace=namespace, - labels=pod.metadata.labels or {} - ) - for pod in pods.items - } - console.print(f"Found matching pods: {[pod.name for pod in selected]}") - return selected - - except Exception as e: - console.print(f"[yellow]Warning: Failed to get selected pods: {str(e)}[/yellow]") - return set() - - def _process_ingress_rule(self, rule: dict, target_pods: Set[NetworkNode]): - from_peers = rule.get('from', []) or rule.get('_from', []) - if not from_peers: - return - - for from_peer in from_peers: - namespace_selector = from_peer.get('namespace_selector') or from_peer.get('namespaceSelector') - pod_selector = from_peer.get('pod_selector') or from_peer.get('podSelector') - - try: - # When we have both selectors in same peer (AND condition) - if namespace_selector and pod_selector: - ns_label_selector = self._build_label_selector(namespace_selector) - namespaces = self.core_api.list_namespace(label_selector=ns_label_selector) - console.print(f"Found namespaces matching selector: {[ns.metadata.name for ns in namespaces.items]}") - - # For each matching namespace, find matching pods - for ns in namespaces.items: - pod_label_selector = self._build_label_selector(pod_selector) - pods = self.core_api.list_namespaced_pod( - ns.metadata.name, - label_selector=pod_label_selector - ) - console.print(f"Checking pods in namespace {ns.metadata.name}") - - # Add edges for each matching pod - for pod in pods.items: - source = NetworkNode( - name=pod.metadata.name, - kind='pod', - namespace=ns.metadata.name, - labels=pod.metadata.labels or {} - ) - self._add_node(source) - for target in target_pods: - console.print(f"Adding edge: {source.namespace}/{source.name} -> {target.namespace}/{target.name}") - self._add_edge(source, target, 'allow') - - # Handle single namespace selector - elif namespace_selector: - label_selector = self._build_label_selector(namespace_selector) - namespaces = self.core_api.list_namespace(label_selector=label_selector) - for ns in namespaces.items: - source = NetworkNode( - name=ns.metadata.name, - kind='namespace', - namespace='', - labels=ns.metadata.labels or {} - ) - self._add_node(source) - for target in target_pods: - console.print(f"Adding namespace edge: {source.name} -> {target.name}") - self._add_edge(source, target, 'allow') - - # Handle single pod selector - elif pod_selector: - source_pods = self._get_selected_pods(self.namespace, pod_selector) - for source in source_pods: - for target in target_pods: - console.print(f"Adding edge: {source.name} -> {target.name}") - self._add_node(source) - self._add_edge(source, target, 'allow') - - except Exception as e: - console.print(f"[yellow]Warning: Error processing selectors: {str(e)}[/yellow]") - - def _process_egress_rule(self, rule: dict, source_pods: Set[NetworkNode]): - """Process an egress rule and add relevant edges""" - console.print(f"Processing egress rule for sources: {[pod.name for pod in source_pods]}") - - for to_peer in rule.get('to', []): - target_pods = self._get_pods_from_peer(to_peer) - console.print(f"Found target pods: {[pod.name for pod in target_pods]}") - - for source in source_pods: - for target in target_pods: - console.print(f"Adding edge: {source.name} -> {target.name}") - self._add_node(target) - self._add_edge(source, target, 'allow') - - def _get_pods_from_peer(self, peer: dict) -> Set[NetworkNode]: - """Get pods that match both namespace and pod selectors when specified""" - pods = set() - - # Handle snake_case vs camelCase - pod_selector = peer.get('pod_selector') or peer.get('podSelector') - namespace_selector = peer.get('namespace_selector') or peer.get('namespaceSelector') - - try: - # If we have both selectors, we need pods matching both conditions - if namespace_selector and pod_selector: - # First get matching namespaces - ns_label_selector = self._build_label_selector(namespace_selector) - namespaces = self.core_api.list_namespace(label_selector=ns_label_selector) - - # Then for each matching namespace, get matching pods - pod_label_selector = self._build_label_selector(pod_selector) - for ns in namespaces.items: - ns_pods = self.core_api.list_namespaced_pod( - ns.metadata.name, - label_selector=pod_label_selector - ) - for pod in ns_pods.items: - pods.add(NetworkNode( - name=pod.metadata.name, - kind='pod', - namespace=ns.metadata.name, - labels=pod.metadata.labels or {} - )) - console.print(f"Found pods in namespace {ns.metadata.name}: {[p.name for p in pods]}") - - # If just namespace selector - elif namespace_selector: - label_selector = self._build_label_selector(namespace_selector) - namespaces = self.core_api.list_namespace(label_selector=label_selector) - for ns in namespaces.items: - pods.add(NetworkNode( - name=ns.metadata.name, - kind='namespace', - namespace='', - labels=ns.metadata.labels or {} - )) - - # If just pod selector - elif pod_selector: - pods.update(self._get_selected_pods(self.namespace, pod_selector)) - - except Exception as e: - console.print(f"[yellow]Warning: Error getting pods from peer: {str(e)}[/yellow]") - - return pods - - def _build_label_selector(self, selector: dict) -> str: - """Build a label selector string from a selector dict""" - if not selector: - return '' - - parts = [] - # Handle snake_case match_labels - match_labels = selector.get('match_labels') or selector.get('matchLabels', {}) - for key, value in match_labels.items(): - parts.append(f"{key}={value}") - - return ','.join(parts) - - def _add_node(self, node: NetworkNode): - """Add a node to the graph if it doesn't exist""" - node_id = f"{node.namespace}/{node.name}" - if node_id not in self.graph: - self.graph.add_node( - node_id, - kind=node.kind, - namespace=node.namespace, - labels=node.labels - ) - - def _add_edge(self, source: NetworkNode, target: NetworkNode, policy_type: str): - """Add an edge between nodes""" - source_id = f"{source.namespace}/{source.name}" - target_id = f"{target.namespace}/{target.name}" - - self.graph.add_edge( - source_id, - target_id, - type=policy_type - ) - - def _draw_nodes(self, pos): - """Draw nodes with different colors based on type""" - for kind in ['pod', 'namespace', 'ipblock']: - nodes = [n for n, d in self.graph.nodes(data=True) if d['kind'] == kind] - if nodes: - nx.draw_networkx_nodes( - self.graph, - pos, - nodelist=nodes, - node_color=self.colors[kind], - node_size=1000, - alpha=0.8 - ) - - def _draw_edges(self, pos): - """Draw edges with different colors based on policy type""" - for policy_type in ['allow', 'deny']: - edges = [(u, v) for u, v, d in self.graph.edges(data=True) - if d['type'] == policy_type] - if edges: - nx.draw_networkx_edges( - self.graph, - pos, - edgelist=edges, - edge_color=self.colors[policy_type], - arrows=True, - arrowsize=20 - ) - - def _add_labels(self, pos): - """Add labels to nodes""" - labels = {} - for node in self.graph.nodes(): - name = node.split('/')[-1] - kind = self.graph.nodes[node]['kind'] - labels[node] = f"{kind}\n{name}" - - nx.draw_networkx_labels( - self.graph, - pos, - labels, - font_size=8, - font_weight='bold' - ) \ No newline at end of file diff --git a/tests/test_policy.py b/tests/test_policy.py index 809887d..6887bd6 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import Mock, patch -from src.policy import PolicyParser + +from src.knetvis.policy import PolicyParser + def test_load_policy_file(tmp_path): # Create a test policy file @@ -26,23 +27,28 @@ def test_load_policy_file(tmp_path): policy_file.write_text(policy_content) parser = PolicyParser() - policy = parser.load_policy_file(str(policy_file)) + policies = parser.load_policy_file(str(policy_file)) + + # Get the first policy since we only have one + policy = policies[0] + assert policy["metadata"]["name"] == "test-policy" + assert policy["spec"]["podSelector"]["matchLabels"]["app"] == "web" - assert policy['metadata']['name'] == 'test-policy' - assert policy['spec']['podSelector']['matchLabels']['app'] == 'web' -@patch('kubernetes.client.NetworkingV1Api') +@patch("kubernetes.client.NetworkingV1Api") def test_get_namespace_policies(mock_api): # Mock the kubernetes API response mock_policy = Mock() mock_policy.to_dict.return_value = { - 'metadata': {'name': 'test-policy'}, - 'spec': {'podSelector': {}} + "metadata": {"name": "test-policy"}, + "spec": {"podSelector": {}}, } - mock_api.return_value.list_namespaced_network_policy.return_value.items = [mock_policy] + mock_api.return_value.list_namespaced_network_policy.return_value.items = [ + mock_policy + ] parser = PolicyParser() - policies = parser.get_namespace_policies('default') + policies = parser.get_namespace_policies("default") assert len(policies) == 1 - assert policies[0]['metadata']['name'] == 'test-policy' \ No newline at end of file + assert policies[0]["metadata"]["name"] == "test-policy" diff --git a/tests/test_simulator.py b/tests/test_simulator.py index 809887d..6887bd6 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import Mock, patch -from src.policy import PolicyParser + +from src.knetvis.policy import PolicyParser + def test_load_policy_file(tmp_path): # Create a test policy file @@ -26,23 +27,28 @@ def test_load_policy_file(tmp_path): policy_file.write_text(policy_content) parser = PolicyParser() - policy = parser.load_policy_file(str(policy_file)) + policies = parser.load_policy_file(str(policy_file)) + + # Get the first policy since we only have one + policy = policies[0] + assert policy["metadata"]["name"] == "test-policy" + assert policy["spec"]["podSelector"]["matchLabels"]["app"] == "web" - assert policy['metadata']['name'] == 'test-policy' - assert policy['spec']['podSelector']['matchLabels']['app'] == 'web' -@patch('kubernetes.client.NetworkingV1Api') +@patch("kubernetes.client.NetworkingV1Api") def test_get_namespace_policies(mock_api): # Mock the kubernetes API response mock_policy = Mock() mock_policy.to_dict.return_value = { - 'metadata': {'name': 'test-policy'}, - 'spec': {'podSelector': {}} + "metadata": {"name": "test-policy"}, + "spec": {"podSelector": {}}, } - mock_api.return_value.list_namespaced_network_policy.return_value.items = [mock_policy] + mock_api.return_value.list_namespaced_network_policy.return_value.items = [ + mock_policy + ] parser = PolicyParser() - policies = parser.get_namespace_policies('default') + policies = parser.get_namespace_policies("default") assert len(policies) == 1 - assert policies[0]['metadata']['name'] == 'test-policy' \ No newline at end of file + assert policies[0]["metadata"]["name"] == "test-policy" diff --git a/tests/test_visualizer.py b/tests/test_visualizer.py index 5f2b286..52b80b2 100644 --- a/tests/test_visualizer.py +++ b/tests/test_visualizer.py @@ -1,50 +1,47 @@ -import pytest from unittest.mock import Mock, patch -import networkx as nx -from src.visualizer import NetworkVisualizer -@patch('kubernetes.client.CoreV1Api') +from src.knetvis.visualizer import NetworkVisualizer + + +@patch("kubernetes.client.CoreV1Api") def test_create_graph(mock_core_api): # Mock pod list response mock_pod = Mock() - mock_pod.metadata.name = 'test-pod' - mock_pod.metadata.labels = {'app': 'web'} + mock_pod.metadata.name = "test-pod" + mock_pod.metadata.labels = {"app": "web"} mock_core_api.return_value.list_namespaced_pod.return_value.items = [mock_pod] visualizer = NetworkVisualizer() # Test policy that selects the pod - policies = [{ - 'metadata': {'namespace': 'default'}, - 'spec': { - 'podSelector': { - 'matchLabels': {'app': 'web'} + policies = [ + { + "metadata": {"namespace": "default"}, + "spec": { + "podSelector": {"matchLabels": {"app": "web"}}, + "ingress": [ + {"from": [{"podSelector": {"matchLabels": {"role": "frontend"}}}]} + ], }, - 'ingress': [{ - 'from': [{ - 'podSelector': { - 'matchLabels': {'role': 'frontend'} - } - }] - }] } - }] + ] - visualizer.create_graph('default', policies) + visualizer.create_graph("default", policies) # Verify graph structure assert visualizer.graph.number_of_nodes() > 0 + def test_save_graph(tmp_path): visualizer = NetworkVisualizer() # Add some test nodes and edges - visualizer.graph.add_node('pod1', kind='pod', namespace='default', labels={}) - visualizer.graph.add_node('pod2', kind='pod', namespace='default', labels={}) - visualizer.graph.add_edge('pod1', 'pod2', type='allow') + visualizer.graph.add_node("pod1", kind="pod", namespace="default", labels={}) + visualizer.graph.add_node("pod2", kind="pod", namespace="default", labels={}) + visualizer.graph.add_edge("pod1", "pod2", type="allow") # Test saving the graph output_file = str(tmp_path / "test_graph.png") visualizer.save_graph(output_file) - assert tmp_path.exists() \ No newline at end of file + assert tmp_path.exists()