Python contextvars 模組教學

Posted on  Sep 21, 2023  in  Python 程式設計 - 高階  by  Amo Chen  ‐ 4 min read

大家都知道執行緒(Threads)之間會共用 Process 的記憶體,這種共用的情況有可能造成 Race Condition, 使得程式出現不可預期的行為或錯誤,所幸這個問題可以透過 threadling.local 解決。

而 Python 3.4 之後推出 asyncio 模組,使得 Python 具備執行非同步 I/O (asynchronous I/O) 的能力,開發者可以同時結合 multiprocessing , threading 以及 asyncio 將 Python 效能提升至全新檔次,但是當結合 threading 以及 asyncio 時,可能會遭遇一個問題, 多個協程(coroutines)可能會在同 1 個執行緒中執行,因此多個 coroutines 也可能有互相影響的情況!所以有了 contextvars 模組,用以解決 coroutines 互相干擾的情況,將每個 coroutine 以 Context 切開,避免互相干擾!

Thread-local variables are insufficient for asynchronous tasks that execute concurrently in the same OS thread. Any context manager that saves and restores a context value using threading.local() will have its context values bleed to other code unexpectedly when used in async/await code.

本文環境

  • Python 3.8 以上

Coroutines 互相干擾的範例

下列是 1 個結合 threading 與 asyncio 的範例程式,這個範例只是用 2 個執行緒搭配 asyncio 分別讓 5 個 asyncio Task 分別對執行緒的私有資料做 10 次的 +1 運算:

import asyncio
import threading


t_local = threading.local()


async def _async_count(thread_name):
    await asyncio.sleep(1)
    global t_local
    for _ in range(0, 10):
        t_local.v += 1
    print(
        'thread_name:', thread_name,
        '|',
        't_local.v:', t_local.v,
    )


async def async_count(thread_name):
    global t_local
    t_local.v = 0
    tasks = [
        _async_count(thread_name) for _ in range(0, 5)
    ]
    await asyncio.gather(*tasks)
    print(f'{thread_name} -> ', t_local.v, flush=True)


def count(thread_name):
    asyncio.run(async_count(thread_name))


t1 = threading.Thread(target=count, args=('t1', ))
t2 = threading.Thread(target=count, args=('t2', ))
t1.start()
t2.start()
t1.join()
t2.join()

上述範例執行結果如下,雖然透過 threading.local() 使得每個執行緒有各自的私有空間,不過執行緒內的 coroutines 仍然可以存取同 1 個執行緒內的 t_local 私有空間,所以最後可以看到 2 個執行緒各自做了 50 次的 +1 運算,得到結果為 50:

thread_name: t1 | t_local.v: 10
thread_name: t1 | t_local.v: 20
thread_name: t1 | t_local.v: 30
thread_name: t1 | t_local.v: 40
thread_name: t1 | t_local.v: 50
thread_name: t2 | t_local.v: 10
thread_name: t2 | t_local.v: 20
thread_name: t2 | t_local.v: 30
thread_name: t2 | t_local.v: 40
thread_name: t2 | t_local.v: 50
t1 ->  50
t2 ->  50

ContextVar

如果想讓 coroutines 之間不互相干擾,除了在 coroutine 內定義私有變數之外,也可以透過 contextvars 讓 coroutines 有獨立的 Context, 讀取寫入到各自的 Context 內,所以上述範例如果要讓 coroutines 擁有獨立的 Context, 可以將 t_local 改為 ContextVar :

import asyncio
import threading
from contextvars import ContextVar


var = ContextVar('var', default=0)


async def _async_count(thread_name):
    await asyncio.sleep(1)
    for _ in range(0, 10):
        var.set(var.get() + 1)
    print(
        'thread_name:', thread_name,
        '|',
        'var:', var.get(),
    )


async def async_count(thread_name):
    tasks = [
        _async_count(thread_name) for _ in range(0, 5)
    ]
    await asyncio.gather(*tasks)
    print(f'{thread_name} -> ', var.get(), flush=True)


def count(thread_name):
    asyncio.run(async_count(thread_name))


t1 = threading.Thread(target=count, args=('t1', ))
t2 = threading.Thread(target=count, args=('t2', ))
t1.start()
t2.start()
t1.join()
t2.join()

上述範例執行結果如下,可以看到每個協程都沒有互相干擾,因此 count 都只有到 10, 而最神奇的是在 async_count() 得到的結果為 0:

thread_name: t1 | var: 10
thread_name: t1 | var: 10
thread_name: t2 | var: 10
thread_name: t1 | var: 10
thread_name: t1 | var: 10
thread_name: t1 | var: 10
thread_name: t2 | var: 10
thread_name: t2 | var: 10
thread_name: t2 | var: 10
thread_name: t2 | var: 10
t1 ->  0
t2 ->  0

這是因為 async_count() 內的 var 也是 1 個獨立的 Context, 因此不會受到 _async_count() 的影響!

如果把 var 改為傳參數傳入 _async_count 也一樣,無法作用!

import asyncio
import threading
from contextvars import ContextVar


var = ContextVar('var', default=0)


async def _async_count(thread_name, var):
    await asyncio.sleep(1)
    for _ in range(0, 10):
        var.set(var.get() + 1)
    print(
        'thread_name:', thread_name,
        '|',
        'var:', var.get(),
    )


async def async_count(thread_name):
    tasks = [
        _async_count(thread_name, var) for _ in range(0, 5)
    ]
    await asyncio.gather(*tasks)
    print(f'{thread_name} -> ', var.get(), flush=True)


def count(thread_name):
    asyncio.run(async_count(thread_name))


t1 = threading.Thread(target=count, args=('t1', ))
t2 = threading.Thread(target=count, args=('t2', ))
t1.start()
t2.start()
t1.join()
t2.join()

上述範例執行結果如下所示,可以看到就算將 varasync_count() 傳到 _async_count() 也沒有改變結果,證明 ContextVar 是以 coroutine 為單位的私有空間:

thread_name: t1 | var: 10
thread_name: t1 | var: 10
thread_name: t2 | var: 10
thread_name: t1 | var: 10
thread_name: t1 | var: 10
thread_name: t1 | var: 10
thread_name: t2 | var: 10
thread_name: t2 | var: 10
thread_name: t2 | var: 10
thread_name: t2 | var: 10
t1 ->  0
t2 ->  0

ContextVar 可以取代 threading.local()

是的, ContextVar 可以取代 threading.local(), 下列範例完美示範 2 個執行緒分別對 var 做 +1 運算 10 次,結果不會互相影響:

import asyncio
import threading
from contextvars import ContextVar


var = ContextVar('var', default=0)


def count(thread_name):
    for _ in range(0, 10):
        var.set(var.get() + 1)
    print(f'{thread_name} ->', var.get())


t1 = threading.Thread(target=count, args=('t1', ))
t2 = threading.Thread(target=count, args=('t2', ))
t1.start()
t2.start()
t1.join()
t2.join()

上述範例執行結果如下,可以發現 ContextVar 可以取代 threading.local():

t1 -> 10
t2 -> 10

雖然 ContextVar 可以取代 threading.local() , 但是要小心 ContextVar 是以 coroutine 為單位獨立的特性,如果牽扯到 coroutine 的話,就會有非預期的結果發生!

認識 ContextVar 相關方法

get(), set()

ContextVar 的使用相當簡單,只需要認識 get()set() 方法即可,例如下列範例:

from contextvars import ContextVar

var = ContextVar('var')

def main():
    print('var:', var.get())


var.set('Hello')
main()

var.set('World')
main()

上述範例執行結果如下:

var: Hello
var: World

如果要為 ContextVar 設定預設值,可以加上 default 參數:

var = ContextVar('var', default=0)

reset()

呼叫 set() 方法時,該方法會回傳 Token 實例,該實例會存呼叫 set() 之前的舊值在 old_value 屬性中,這個 token 就可以用來代入呼叫 reset() 方法, reset() 方法會將 ContextVar 的值設定回 token 裡的 old_value, 如下列範例:

from contextvars import ContextVar

var = ContextVar('var')

def main():
    print('var:', var.get())


var.set('Hello')
token = var.set('World')
main()
print('Will call reset() to reset var to:', token.old_value)
var.reset(token)
main()

上述範例執行結果如下,可以看到我們順利將 ContextVar 恢復為舊值 Hello :

var: World
Will call reset() to reset var to: Hello
var: Hello

ContextVar 的進一步應用

如果你有些函數需要一層一層傳遞參數到很底層,例如:

def func_lv2(s):
    func_lv3(s)


def func_lv3(s):
    print('lv3 got', s)


def func_lv1():
    s = 1
    func_lv2(s)


func_lv1()

這時可以用 ContextVar 直接略過中間層,直接讓最底層取得:

from contextvars import ContextVar


var = ContextVar('var')


def func_lv2():
    func_lv3()


def func_lv3():
    print('lv3 got', var.get())


def func_lv1():
    var.set(1)
    func_lv2()


func_lv1()

或者你可以像 Flask 一樣,把 request object 變成 ContextVar, 搭配 LocalProxy ,讓任何 request handler 都可以藉由 import 的方式取得 request object:

from flask import request

When a Flask application begins handling a request, it pushes a request context, which also pushes an app context. When the request ends it pops the request context then the application context. … Context locals are implemented using Python’s contextvars and Werkzeug’s LocalProxy. Python manages the lifetime of context vars automatically, and local proxy wraps that low-level interface to make the data easier to work with.

這些用法同樣要注意 ContextVar 是以 coroutine 為單位獨立的特性,如果這中間有任何 1 層是使用 coroutine, 那可能結果就可能會不一樣!

總結

contextvars 是 Python 3.7 之後才推出的模組,較少人談到它的用途,不過它確實已經被運用到 Flask, Tornado 等框架之中,足見其重要性。

以上是關於 contextvars 模組的簡介,更詳細的內容請參閱官方文件

Happy Coding!

References

contextvars — Context Variables

PEP 567 – Context Variables

對抗久坐職業傷害

研究指出每天增加 2 小時坐著的時間,會增加大腸癌、心臟疾病、肺癌的風險,也造成肩頸、腰背疼痛等常見問題。

然而對抗這些問題,卻只需要工作時定期休息跟伸展身體即可!

你想輕鬆改變現狀嗎?試試看我們的 PomodoRoll 番茄鐘吧! PomodoRoll 番茄鐘會根據你所設定的專注時間,定期建議你 1 項辦公族適用的伸展運動,幫助你打敗久坐所帶來的傷害!

贊助我們的創作

看完這篇文章了嗎? 休息一下,喝杯咖啡吧!

如果你覺得 MyApollo 有讓你獲得實用的資訊,希望能看到更多的技術分享,邀請你贊助我們一杯咖啡,讓我們有更多的動力與精力繼續提供高品質的文章,感謝你的支持!