Python threading.local() 解說

Posted on  Sep 4, 2023  in  Python 程式設計 - 高階  by  Amo Chen  ‐ 3 min read

我們都知道多個執行緒(thread)之間會共用 Process 的記憶體,那你覺得以下範例程式的執行結果會是什麼呢?這是 2 個執行緒分別做 +1 與 -1 運算各 100,000 次的 Python 程式:

import threading


def count(thread_name, step=1):
    global v
    for i in range(0, 100000):
        v += 1 * step
    print(f'{thread_name} -> ', v, flush=True)


v = 0
t1 = threading.Thread(target=count, args=('t1', 1, ))
t2 = threading.Thread(target=count, args=('t2', -1, ))
t1.start()
t2.start()
t1.join()
t2.join()

這段範例程式的執行結果,就跟本文要解說的 threading.local() 有關。

本文環境

  • Python 3

範例程式執行結果

以下這段範例程式,單從程式碼上來看,應該會 t1 執行緒將 v 從 0 加到 10,000, 然後 t2 則從 0 開始扣到 -10,000:

import threading


def count(thread_name, step=1):
    global v
    for i in range(0, 100000):
        v += 1 * step
    print(f'{thread_name} -> ', v, flush=True)


v = 0
t1 = threading.Thread(target=count, args=('t1', 1, ))
t2 = threading.Thread(target=count, args=('t2', -1, ))
t1.start()
t2.start()
t1.join()
t2.join()

不過實際上執行結果卻是(或類似結果):

t2 ->  -18918
t1 ->  0

這是由於執行緒 t1, t2 共用記憶體中的 v ,造成 2 個執行緒互相干擾,最終導致結果不正確。

threading.local()

如果是這種因為執行緒之間共用記憶體而造成相互干擾(interference)的問題,就需要用 threading.local() 解決, threading.local() 顧名思義就是執行緒自己的資料,或稱 Thread-local data, Thread-local storage(TLS), Thread-private 。

很多語言都有 thread-local data 的功能,顧名思義就是執行緒自己獨有的資料,不能給其他執行緒存取。

會需要這功能是由於執行緒之間會共用 process 的記憶體,資源的共享就會造成彼此干擾的問題,除了用 lock 等機制來控制資源的存取之外,也可以用 thread-local data 解決,也就是讓執行緒有自己獨享的資源,譬如資料庫連線如果被多個執行緒共享,那就有可能某一執行緒關掉連線,造成其他執行緒無法使用的窘境,這種情境就很適合使用 thread-local data 。

使用方法很簡單,只要把會被干擾的變數改為使用 threading.local() 初始化,並將值設定為其中一個屬性即可,例如:

thread_local = threading.local()
thread_local.my_value = 0

所以,前述有問題的 Python 程式碼,可以用 threading.local() 改為下列範例,以修正其錯誤:

import threading


t_local = threading.local()


def init_thread_local():
    global t_local
    t_local.v = 0


def count(thread_name, step=1):
    global t_local
    init_thread_local()
    for i in range(0, 100000):
        t_local.v += 1 * step
    print(f'{thread_name} -> ', t_local.v, flush=True)


t1 = threading.Thread(target=count, args=('t1', 1, ))
t2 = threading.Thread(target=count, args=('t2', -1, ))
t1.start()
t2.start()
t1.join()
t2.join()

上述範例執行結果如下 ,可以看到 t1, t2 執行緒各自運作正常,互不干擾:

t1 ->  100000
t2 ->  -100000

Threading.local() 怎麼運作的?

實際上, threading.local() 是由 Python 實作的功能,它的本質上也還是執行緒共用的記憶體,所有執行緒都可以存取到由 threading.local() 所建立的 instance, 只是這個 instance 內部用一個 dictionary 區隔不同執行緒的 thread-local data 。

詳細可以閱讀 threadling.local() 原始碼 ,這份原始碼有個私有 class _localimpl ,專門用來管理各個執行緒的 thread-local data, 所以可以看到它會依據 current_thread() 的不同,建立不同的 dictionary 用來儲存執行緒的私有資料。

class _localimpl:
    """A class managing thread-local dicts"""
    __slots__ = 'key', 'dicts', 'localargs', 'locallock', '__weakref__'

    def __init__(self):
        # The key used in the Thread objects' attribute dicts.
        # We keep it a string for speed but make it unlikely to clash with
        # a "real" attribute.
        self.key = '_threading_local._localimpl.' + str(id(self))
        # { id(Thread) -> (ref(Thread), thread-local dict) }
        self.dicts = {}

    def get_dict(self):
        """Return the dict for the current thread. Raises KeyError if none
        defined."""
        thread = current_thread()
        return self.dicts[id(thread)][1]

    def create_dict(self):
        """Create a new dict for the current thread, and return it."""
        localdict = {}
        key = self.key
        thread = current_thread()
        idt = id(thread)
        def local_deleted(_, key=key):
            # When the localimpl is deleted, remove the thread attribute.
            thread = wrthread()
            if thread is not None:
                del thread.__dict__[key]
        def thread_deleted(_, idt=idt):
            # When the thread is deleted, remove the local dict.
            # Note that this is suboptimal if the thread object gets
            # caught in a reference loop. We would like to be called
            # as soon as the OS-level thread ends instead.
            local = wrlocal()
            if local is not None:
                dct = local.dicts.pop(idt)
        wrlocal = ref(self, local_deleted)
        wrthread = ref(thread, thread_deleted)
        thread.__dict__[key] = wrlocal
        self.dicts[idt] = wrthread, localdict
        return localdict

簡單來說,這個管理各個 thread-local data 的 dictionary 內部結構長這樣:

{
    id(Thread): (ref(Thread), thread-local dict),
    ...
}

舉前述範例為例:

{
    (t1_thread_id): ( ref(t1_thread), {'v': 0} ),
    (t2_thread_id): ( ref(t2_thread), {'v': 0} ),
}

總結

threading.local() 雖然大多數時候用不到,但是它在 PySpark, TensorFlow, Keras, PyTorch 等框架都有用到,如果能夠理解它的原理,絕對能夠對於 Python 的 Concurrent Programming 功力有所提升。

Happy Coding!

References

threading — Thread-based parallelism

cpython/Lib/_threading_local.py at 3.11 · python/cpython

對抗久坐職業傷害

研究指出每天增加 2 小時坐著的時間,會增加大腸癌、心臟疾病、肺癌的風險,也造成肩頸、腰背疼痛等常見問題。

然而對抗這些問題,卻只需要工作時定期休息跟伸展身體即可!

你想輕鬆改變現狀嗎?試試看我們的 PomodoRoll 番茄鐘吧! PomodoRoll 番茄鐘會根據你所設定的專注時間,定期建議你 1 項辦公族適用的伸展運動,幫助你打敗久坐所帶來的傷害!

贊助我們的創作

看完這篇文章了嗎? 休息一下,喝杯咖啡吧!

如果你覺得 MyApollo 有讓你獲得實用的資訊,希望能看到更多的技術分享,邀請你贊助我們一杯咖啡,讓我們有更多的動力與精力繼續提供高品質的文章,感謝你的支持!