Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️ Speed up function get_stack_sampler by 28% #364

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

misrasaurabh1
Copy link

📄 28% (0.28x) speedup for get_stack_sampler in pyinstrument/stack_sampler.py

⏱️ Runtime : 4.20 microseconds 3.29 microseconds (best of 63 runs)

📝 Explanation and details

To optimize the given Python program for better performance, we can focus on reducing redundant operations and streamlining the logic. The get_stack_sampler function checks if the stack_sampler attribute exists in thread_locals and creates a StackSampler instance if it does not. This code is already relatively straightforward, but we can make a small improvement to slightly optimize its performance.

First, ensure that accessing the attribute and setting it is done as quickly as possible by minimizing the number of lookups. Python attribute access can be somewhat costly, so we'll reduce the number of attribute lookups slightly. Instead of using hasattr and setattr separately, we can use a single try/except block to handle it.

Here's the optimized version of the get_stack_sampler function.

This optimization minimizes attribute lookups by using a try/except block. If the stack_sampler attribute is not found, the exception is caught, and we create the StackSampler instance and set the attribute in one go.

This should slightly improve the performance of the function by reducing the overhead associated with multiple attribute lookups and method calls.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 6 Passed
🌀 Generated Regression Tests 2 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
⚙️ Existing Unit Tests Details
- test_stack_sampler.py
🌀 Generated Regression Tests Details
from __future__ import annotations

import sys
import textwrap
import threading
import time
import timeit
import types
from typing import Any, Callable

# imports
import pytest  # used for our unit tests
from pyinstrument.low_level.stat_profile import (setstatprofile,
                                                 walltime_coarse_resolution)
from pyinstrument.low_level.types import TimerType
from pyinstrument.stack_sampler import get_stack_sampler
from pyinstrument.util import format_float_with_sig_figs, unwrap


class StackSampler:
    """Manages setstatprofile for Profilers on a single thread"""

    subscribers: list[StackSamplerSubscriber]
    current_sampling_interval: float | None
    last_profile_time: float
    timer_func: Callable[[], float] | None
    has_warned_about_timing_overhead: bool

    def __init__(self) -> None:
        self.subscribers = []
        self.current_sampling_interval = None
        self.last_profile_time = 0.0
        self.timer_func = None
        self.has_warned_about_timing_overhead = False

    def subscribe(
        self,
        target: StackSamplerSubscriberTarget,
        *,
        desired_interval: float,
        use_timing_thread: bool | None = None,
        use_async_context: bool,
    ):
        if use_async_context:
            if active_profiler_context_var.get() is not None:
                raise RuntimeError(
                    "There is already a profiler running. You cannot run multiple profilers in the same thread or async context, unless you disable async support."
                )
            active_profiler_context_var.set(target)

        self.subscribers.append(
            StackSamplerSubscriber(
                target=target,
                desired_interval=desired_interval,
                use_timing_thread=use_timing_thread,
                bound_to_async_context=use_async_context,
                async_state=AsyncState("in_context") if use_async_context else None,
            )
        )
        self._update()

    def unsubscribe(self, target: StackSamplerSubscriberTarget):
        try:
            subscriber = next(s for s in self.subscribers if s.target == target)  # type: ignore
        except StopIteration:
            raise StackSampler.SubscriberNotFound()

        if subscriber.bound_to_async_context:
            # (don't need to use context_var.reset() because we verified it was
            # None before we started)
            active_profiler_context_var.set(None)

        self.subscribers.remove(subscriber)

        self._update()

    def _update(self):
        if len(self.subscribers) == 0:
            self._stop_sampling()
            return

        min_subscribers_interval = min(s.desired_interval for s in self.subscribers)
        timing_thread_preferences = set(
            s.use_timing_thread for s in self.subscribers if s.use_timing_thread is not None
        )
        if len(timing_thread_preferences) > 1:
            raise ValueError(
                f"Profiler requested different timing thread preferences from a profiler that is already running."
            )

        use_timing_thread = next(iter(timing_thread_preferences), False)

        if self.current_sampling_interval != min_subscribers_interval:
            self._start_sampling(
                interval=min_subscribers_interval, use_timing_thread=use_timing_thread
            )

    def _start_sampling(self, interval: float, use_timing_thread: bool):
        if use_timing_thread and self.timer_func is not None:
            raise ValueError(
                f"Profiler requested to use the timing thread but this stack sampler is already using a custom timer function."
            )

        timer_type: TimerType

        if self.timer_func:
            timer_type = "timer_func"
        elif use_timing_thread:
            timer_type = "walltime_thread"
        else:
            coarse_resolution = walltime_coarse_resolution()
            if coarse_resolution is not None and coarse_resolution <= interval:
                timer_type = "walltime_coarse"
            else:
                timer_type = "walltime"

        self._check_timing_overhead(interval=interval, timer_type=timer_type)

        self.current_sampling_interval = interval
        if self.last_profile_time == 0.0:
            self.last_profile_time = self._timer()

        setstatprofile(
            target=self._sample,
            interval=interval,
            context_var=active_profiler_context_var,
            timer_type=timer_type,
            timer_func=self.timer_func,
        )

    def _stop_sampling(self):
        setstatprofile(None)
        self.current_sampling_interval = None
        self.last_profile_time = 0.0

    def _sample(self, frame: types.FrameType, event: str, arg: Any):
        if event == "context_changed":
            new, old, coroutine_stack = arg

            for subscriber in self.subscribers:
                if subscriber.target == old:
                    full_stack = build_call_stack(frame, event, arg)
                    if coroutine_stack:
                        full_stack.extend(reversed(coroutine_stack))
                        subscriber.async_state = AsyncState(
                            "out_of_context_awaited", info=full_stack
                        )
                    else:
                        subscriber.async_state = AsyncState(
                            "out_of_context_unknown", info=full_stack
                        )
                elif subscriber.target == new:
                    subscriber.async_state = AsyncState("in_context")
        else:
            now = self._timer()
            time_since_last_sample = now - self.last_profile_time

            call_stack = build_call_stack(frame, event, arg)

            for subscriber in self.subscribers:
                subscriber.target(call_stack, time_since_last_sample, subscriber.async_state)

            self.last_profile_time = now

    def _timer(self):
        if self.timer_func:
            return self.timer_func()
        else:
            return timeit.default_timer()

    def _check_timing_overhead(self, interval: float, timer_type: TimerType):
        if self.has_warned_about_timing_overhead:
            return
        if IGNORE_OVERHEAD_WARNING:
            return

        overheads = timing_overhead()
        overhead = overheads.get(timer_type)
        if overhead is None:
            return

        if timer_type == "walltime":
            if overhead > 300e-9:
                self.has_warned_about_timing_overhead = True
                message_parts: list[str] = []
                message_parts.append(
                    f"""
                    pyinstrument: the timer on your system has an overhead of
                    {overhead * 1e9:.0f} nanoseconds, which is considered
                    high. You might experience longer runtimes than usual, and
                    programs with lots of pure-python code might be distorted.
                    """
                )

                message_parts.append(
                    f"""
                    You might want to try the timing thread option, which can
                    be enabled using --use-timing-thread at the command line,
                    or by setting the use_timing_thread parameter in the
                    Profiler constructor.
                    """
                )

                if "walltime_coarse" in overheads and overheads["walltime_coarse"] < 300e-9:
                    coarse_resolution = walltime_coarse_resolution()
                    message_parts.append(
                        f"""
                        Your system does offer a 'coarse' timer, with a lower
                        overhead ({overheads["walltime_coarse"] * 1e9:.2g}
                        nanoseconds). You can enable it by setting
                        pyinstrument's interval to a value higher than
                        {format_float_with_sig_figs(coarse_resolution,
                        trim_zeroes=True)} seconds. If you're happy with the
                        lower precision, this is the best option.
                        """
                    )

                message_parts.append(
                    f"""
                    If you want to suppress this warning, you can set the
                    environment variable PYINSTRUMENT_IGNORE_OVERHEAD_WARNING
                    to '1'.
                    """
                )

                message = "\n\n".join(
                    textwrap.fill(unwrap(part), width=80) for part in message_parts
                )

                print(message, file=sys.stderr)

    class SubscriberNotFound(Exception):
        pass

thread_locals = threading.local()
from pyinstrument.stack_sampler import get_stack_sampler

# unit tests

# Helper classes and functions for testing
class StackSamplerSubscriber:
    def __init__(self, target, desired_interval, use_timing_thread, bound_to_async_context, async_state):
        self.target = target
        self.desired_interval = desired_interval
        self.use_timing_thread = use_timing_thread
        self.bound_to_async_context = bound_to_async_context
        self.async_state = async_state

class AsyncState:
    def __init__(self, state, info=None):
        self.state = state
        self.info = info

def build_call_stack(frame, event, arg):
    return []

active_profiler_context_var = threading.local()
IGNORE_OVERHEAD_WARNING = False

def timing_overhead():
    return {"walltime": 100e-9, "walltime_coarse": 50e-9}

# Basic Test Cases








def test_no_subscribers():
    codeflash_output = get_stack_sampler()








from __future__ import annotations

import sys
import textwrap
import threading
import time
import timeit
import types
from typing import Any, Callable

# imports
import pytest  # used for our unit tests
from pyinstrument.low_level.stat_profile import (setstatprofile,
                                                 walltime_coarse_resolution)
from pyinstrument.low_level.types import TimerType
from pyinstrument.stack_sampler import get_stack_sampler
from pyinstrument.util import format_float_with_sig_figs, unwrap


class StackSampler:
    """Manages setstatprofile for Profilers on a single thread"""

    subscribers: list[StackSamplerSubscriber]
    current_sampling_interval: float | None
    last_profile_time: float
    timer_func: Callable[[], float] | None
    has_warned_about_timing_overhead: bool

    def __init__(self) -> None:
        self.subscribers = []
        self.current_sampling_interval = None
        self.last_profile_time = 0.0
        self.timer_func = None
        self.has_warned_about_timing_overhead = False

    def subscribe(
        self,
        target: StackSamplerSubscriberTarget,
        *,
        desired_interval: float,
        use_timing_thread: bool | None = None,
        use_async_context: bool,
    ):
        if use_async_context:
            if active_profiler_context_var.get() is not None:
                raise RuntimeError(
                    "There is already a profiler running. You cannot run multiple profilers in the same thread or async context, unless you disable async support."
                )
            active_profiler_context_var.set(target)

        self.subscribers.append(
            StackSamplerSubscriber(
                target=target,
                desired_interval=desired_interval,
                use_timing_thread=use_timing_thread,
                bound_to_async_context=use_async_context,
                async_state=AsyncState("in_context") if use_async_context else None,
            )
        )
        self._update()

    def unsubscribe(self, target: StackSamplerSubscriberTarget):
        try:
            subscriber = next(s for s in self.subscribers if s.target == target)  # type: ignore
        except StopIteration:
            raise StackSampler.SubscriberNotFound()

        if subscriber.bound_to_async_context:
            # (don't need to use context_var.reset() because we verified it was
            # None before we started)
            active_profiler_context_var.set(None)

        self.subscribers.remove(subscriber)

        self._update()

    def _update(self):
        if len(self.subscribers) == 0:
            self._stop_sampling()
            return

        min_subscribers_interval = min(s.desired_interval for s in self.subscribers)
        timing_thread_preferences = set(
            s.use_timing_thread for s in self.subscribers if s.use_timing_thread is not None
        )
        if len(timing_thread_preferences) > 1:
            raise ValueError(
                f"Profiler requested different timing thread preferences from a profiler that is already running."
            )

        use_timing_thread = next(iter(timing_thread_preferences), False)

        if self.current_sampling_interval != min_subscribers_interval:
            self._start_sampling(
                interval=min_subscribers_interval, use_timing_thread=use_timing_thread
            )

    def _start_sampling(self, interval: float, use_timing_thread: bool):
        if use_timing_thread and self.timer_func is not None:
            raise ValueError(
                f"Profiler requested to use the timing thread but this stack sampler is already using a custom timer function."
            )

        timer_type: TimerType

        if self.timer_func:
            timer_type = "timer_func"
        elif use_timing_thread:
            timer_type = "walltime_thread"
        else:
            coarse_resolution = walltime_coarse_resolution()
            if coarse_resolution is not None and coarse_resolution <= interval:
                timer_type = "walltime_coarse"
            else:
                timer_type = "walltime"

        self._check_timing_overhead(interval=interval, timer_type=timer_type)

        self.current_sampling_interval = interval
        if self.last_profile_time == 0.0:
            self.last_profile_time = self._timer()

        setstatprofile(
            target=self._sample,
            interval=interval,
            context_var=active_profiler_context_var,
            timer_type=timer_type,
            timer_func=self.timer_func,
        )

    def _stop_sampling(self):
        setstatprofile(None)
        self.current_sampling_interval = None
        self.last_profile_time = 0.0

    def _sample(self, frame: types.FrameType, event: str, arg: Any):
        if event == "context_changed":
            new, old, coroutine_stack = arg

            for subscriber in self.subscribers:
                if subscriber.target == old:
                    full_stack = build_call_stack(frame, event, arg)
                    if coroutine_stack:
                        full_stack.extend(reversed(coroutine_stack))
                        subscriber.async_state = AsyncState(
                            "out_of_context_awaited", info=full_stack
                        )
                    else:
                        subscriber.async_state = AsyncState(
                            "out_of_context_unknown", info=full_stack
                        )
                elif subscriber.target == new:
                    subscriber.async_state = AsyncState("in_context")
        else:
            now = self._timer()
            time_since_last_sample = now - self.last_profile_time

            call_stack = build_call_stack(frame, event, arg)

            for subscriber in self.subscribers:
                subscriber.target(call_stack, time_since_last_sample, subscriber.async_state)

            self.last_profile_time = now

    def _timer(self):
        if self.timer_func:
            return self.timer_func()
        else:
            return timeit.default_timer()

    def _check_timing_overhead(self, interval: float, timer_type: TimerType):
        if self.has_warned_about_timing_overhead:
            return
        if IGNORE_OVERHEAD_WARNING:
            return

        overheads = timing_overhead()
        overhead = overheads.get(timer_type)
        if overhead is None:
            return

        if timer_type == "walltime":
            if overhead > 300e-9:
                self.has_warned_about_timing_overhead = True
                message_parts: list[str] = []
                message_parts.append(
                    f"""
                    pyinstrument: the timer on your system has an overhead of
                    {overhead * 1e9:.0f} nanoseconds, which is considered
                    high. You might experience longer runtimes than usual, and
                    programs with lots of pure-python code might be distorted.
                    """
                )

                message_parts.append(
                    f"""
                    You might want to try the timing thread option, which can
                    be enabled using --use-timing-thread at the command line,
                    or by setting the use_timing_thread parameter in the
                    Profiler constructor.
                    """
                )

                if "walltime_coarse" in overheads and overheads["walltime_coarse"] < 300e-9:
                    coarse_resolution = walltime_coarse_resolution()
                    message_parts.append(
                        f"""
                        Your system does offer a 'coarse' timer, with a lower
                        overhead ({overheads["walltime_coarse"] * 1e9:.2g}
                        nanoseconds). You can enable it by setting
                        pyinstrument's interval to a value higher than
                        {format_float_with_sig_figs(coarse_resolution,
                        trim_zeroes=True)} seconds. If you're happy with the
                        lower precision, this is the best option.
                        """
                    )

                message_parts.append(
                    f"""
                    If you want to suppress this warning, you can set the
                    environment variable PYINSTRUMENT_IGNORE_OVERHEAD_WARNING
                    to '1'.
                    """
                )

                message = "\n\n".join(
                    textwrap.fill(unwrap(part), width=80) for part in message_parts
                )

                print(message, file=sys.stderr)

    class SubscriberNotFound(Exception):
        pass

thread_locals = threading.local()
from pyinstrument.stack_sampler import get_stack_sampler

# unit tests

# Basic Functionality Tests
def test_get_stack_sampler_no_subscribers():
    codeflash_output = get_stack_sampler()












from pyinstrument.stack_sampler import get_stack_sampler

def test_get_stack_sampler():
    assert get_stack_sampler() == <pyinstrument.stack_sampler.StackSampler object at 0x709502e7fa10>

Optimized with codeflash.ai

To optimize the given Python program for better performance, we can focus on reducing redundant operations and streamlining the logic. The `get_stack_sampler` function checks if the `stack_sampler` attribute exists in `thread_locals` and creates a `StackSampler` instance if it does not. This code is already relatively straightforward, but we can make a small improvement to slightly optimize its performance.

First, ensure that accessing the attribute and setting it is done as quickly as possible by minimizing the number of lookups. Python attribute access can be somewhat costly, so we'll reduce the number of attribute lookups slightly. Instead of using `hasattr` and `setattr` separately, we can use a single `try/except` block to handle it.

Here's the optimized version of the `get_stack_sampler` function.



This optimization minimizes attribute lookups by using a try/except block. If the `stack_sampler` attribute is not found, the exception is caught, and we create the `StackSampler` instance and set the attribute in one go.

This should slightly improve the performance of the function by reducing the overhead associated with multiple attribute lookups and method calls.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant