LangChain 怎麼玩?為了荷包著想,管好你的 prompt 長度(size)
Posted on Mar 5, 2024 in LangChain , Python 程式設計 - 高階 by Amo Chen ‐ 5 min read
每 1 個語言模型都有其限制,而所有語言模型都會有的限制就是 tokens 上限,而 tokens 上限會影響能夠接受的 prompt 長度。
如果你使用的不是開源語言模型,那不控制 prompt 長度所帶來的影響還有伴隨而來的費用。
以 OpenAI 所提供的語言模型 GPT-4 為例,每輸入 1 百萬個 tokens 就需要收費 $30 美元,如果你的應用沒有注意使用者所輸入 prompt 長度,那你很可能會為不必要的 tokens 付出代價。
所以學會如何控制 prompt 長度也是一門重要的課題!因為會幫助你避免踩到語言模型的限制之外,也可以幫助你控制費用支出。
本文環境
$ pip install langchain transformers
本文需要 Ollama 指令與 Meta 公司提供的 llama2
模型(model),ollama
指令請至 Ollama 官方頁面下載安裝,安裝完成之後即可執行指令安裝 llama2
模型,其安裝指令如下:
$ ollama run llama2
p.s. 執行上述模型需要至少 8GB 記憶體, 3.8G 硬碟空間
自然語言處理的 token 是什麼?
自然語言處理(NLP, Natural Language Processing)與機器學習(machine learning)的相關技術中,有 1 個稱為 Tokenization (或稱斷詞)。該技術是指將一連串的文本(text)拆分成多個部分的方法,這個方法的拆分標準可以用字元(characters)為單位,或者以單字(words)為單位,舉 “I love LangChain.” 為例,若以單字為拆分單位, tokenization 之後會變成 “I”, “love”, “LangChain” 3 個 tokens 。
Tokens 的應用很廣,包含現在常見的搜尋引擎,基本原理就是將使用者輸入的關鍵字拆成多個 tokens 之後,再到資料庫搜尋哪些網頁/文件含有這些 tokens, 當然這些網頁/文件在存進資料庫之前,也會進行所謂的 tokenization ,將 token 與文件編號做關聯,所以查詢到匹配的關鍵字時,就能順便知道哪些文件與這個關鍵字相關。
例如我們查詢關鍵字 LangChain
時(LangChain
是 token),從資料庫關聯可以找到 LangChain
關聯到文件 2, 3, 8, 10 這 4 個文件,所以就可以按照順序顯示文件標題與摘要給使用者:
"token", "文件編號"
"LangChain", "2,3,8,10"
"Python", "2,9,11"
Tokenization 同樣地也運用在語言模型(Language Model)相關的應用中,例如將使用者輸入的文字經過 tokenization 之後,就可以進一步轉為 embedding, 再找到 embedding 相似的文件,最後再交由語言模型產生回應。
p.s. 想知道 GPT-3.5, GPT-4 等如何做 tokenization 的話,可以使用 Tiktokenizer 測試
語言模型能接受的 Tokens 數量有限!
現在語言模型能夠接受的 Tokens 數量有上限,例如:
- GPT-4-Turbo 上限 128,000
- GPT-3.5-Turbo 上限 16,385
- Llama2 上限 4,096
值得注意的是 token 的拆分單位不一定相同, OpenAI 系列的模型大概是 4 個 English 字元為 1 個 token, 或者 3/4 的單字長度作為 1 個 token, 沒有絕對的計算方式,在使用上要注意有 tokenization 的單位區別。
也由於 Tokens 的數量設有上限,所以在處理語言模型相關應用時,特別要注意是否會超過語言模型能處理的 token 上限。
此外, OpenAI 語言模型的計算計算也以 tokens 的輸入與輸出作為基準,舉 GPT-4 為例,每輸入 1 百萬個 tokens 就需要收費 $30 美元,每輸出 1 百萬個 tokens 就需要收費 $60 美元,所以對於使用 OpenAI 語言模型提供服務的應用來說,藉控制 prompt 的長度以控制營運支出,是相當重要的一件事。
LangChain 取得 tokens 長度
以下是取得 tokens 長度的最簡單範例,有 2 種方法:
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage
llm = Ollama(model='llama2')
num = llm.get_num_tokens_from_messages([HumanMessage('Hi, there. How are you today?')])
print(f'Tokens: {num}')
num = llm.get_num_tokens('Hi, there. How are you today?')
print(f'Tokens: {num}')
上述範例執行結果如下:
Tokens: 11
Tokens: 9
可以看到使用 llama2
對 Hi, there. How are you today?
字串進行 tokenization 後,得到 token size 為 9, 對 HumanMessage('Hi, there. How are you today?')
做 tokenization 之後為 11 。
為什麼會有差異呢?
其實是因為 HumanMessage('Hi, there. How are you today?')
最後被加工為 Human: Hi, there. How are you today?
的緣故,多了 Human:
導致 tokens 長度不同。
得到 tokens 長度的方法有 2 種。
一為呼叫 <model>.get_num_tokens_from_messages()
(文件),一為呼叫 <model>.llm.get_num_tokens()
,差別在於 <model>.llm.get_num_tokens()
接受純字串。
如果是使用 OpenAI 的語言模型也有相同的方法可以呼叫:
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo")
num = llm.get_num_tokens_from_messages([HumanMessage('Hi, there. How are you today?')])
print(f'Tokens: {num}')
num = llm.get_num_tokens('Hi, there. How are you today?')
print(f'Tokens: {num}')
上述執行結果如下:
Tokens: 16
Tokens: 9
LangChain 控制 prompt 長度的範例
知道如何取得 tokens 長度之後,我們就可以進一步在 chain 中對 prompt 長度做控管!
做法也很簡單,就是在 prompt 送進語言模型之前做攔截、修改長度:
prompt | <修改 prompt> | 語言模型
而 prompt template 在結合 input values 之後,會變成 ChatPromptValue 的實例,該實例(instance)有個 to_messages()
方法可以呼叫,我們只要把 to_messages()
回傳的結果,放進 <model>.get_num_tokens_from_messages()
得到 prompt 的長度之後,如果 prompt 長度超過限制,就丟掉舊的 messages, 再回傳新的 ChatPromptValue 實例即可,如下列範例中的 condense_prompt()
函式:
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.messages import HumanMessage, AIMessage
chat_history = []
def condense_prompt(prompt: ChatPromptValue) -> ChatPromptValue:
messages = prompt.to_messages()
num_tokens = llm.get_num_tokens_from_messages(messages)
recent_messages = messages[2:]
while num_tokens > 500:
recent_messages = recent_messages[2:]
num_tokens = llm.get_num_tokens_from_messages(
messages[:2] + recent_messages
)
# also update chat history
chat_history = recent_messages
messages = messages[:2] + recent_messages
return ChatPromptValue(messages=messages)
llm = Ollama(model='llama2')
prompt = ChatPromptTemplate.from_messages([
('system', 'You are a powerful chat bot.'),
MessagesPlaceholder(variable_name='chat_history'),
("user", "{input}"),
])
chain = prompt | condense_prompt | llm
input_text = input('>>> ')
while input_text.lower() != 'bye':
if input_text:
response = chain.invoke({
'input': input_text,
'chat_history': chat_history,
})
chat_history.append(HumanMessage(content=input_text))
chat_history.append(AIMessage(content=response))
print(response)
input_text = input('>>> ')
上述範例的重點在於 condense_prompt()
函式會對傳進來的 ChatPromptValue
檢查 tokens 總長度,如果超過 500, 就只保留第 1, 2 個 messages, 接著再從剩下的 messages 扣掉前 2 個,再檢查是否剩下的 messages 總長度是否超過 500 ,如果沒有超過 500 的話,就合併最早的 2 個訊息(內含 system message),並回傳新的 ChatPromptValue
,如果還是超過 500 就再繼續扣掉 2 個 messages, 之所以都是以 2 為單位扣除,是因為一次可以扣掉 1 個 HumanMessage
與 1 個 AIMessage
,此處的邏輯完全可以按照需求修改。
噢對,由於我們會留下對話紀錄(chat history),所以也需要在 condense_prompt()
中一併更新對話紀錄 chat_history
, 避免被丟掉的對話紀錄又重新回到 prompt 中,看到此處大家應該更能理解為什麼跟 ChatGPT 聊著聊著 ChatGPT 就失憶的原因囉,是因為踩到 tokens 上限,只能將對話紀錄做縮減的緣故。( p.s. 此處還可以有更好的寫法就留給各位摸索囉 )
以上就是如何修改 prompt 長度避免超過 tokens 上限的做法。
總結
基於語言模型有 tokens 上限與 tokens 數量可能是語言模型計費項目這 2 個理由,使得控制 prompt 長度是實作語言模型相關應用(application)需要注意的一項問題,但如果不是將語言模型打造成聊天應用的話,這個問題就會相對小一些。
以上!
Enjoy!
References
What are tokens and how to count them? | OpenAI Help Center