-
Notifications
You must be signed in to change notification settings - Fork 1k
/
bot.py
289 lines (248 loc) · 8.9 KB
/
bot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import os
import sys
INFERENCE_DIR = os.path.dirname(os.path.abspath(__file__))
# TODO: PYTHONPATH hacks are never a good idea. clean this up later
sys.path.append(os.path.join(INFERENCE_DIR, '..'))
import cmd
import torch
import argparse
import conversation as convo
import retrieval.wikipedia as wp
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
from accelerate import infer_auto_device_map, init_empty_weights
class StopWordsCriteria(StoppingCriteria):
    """Stopping criterion that halts generation at a stop word and streams text.

    ``transformers`` calls this object after every generated token. It
    decodes the newest token, appends it to the text generated so far, and
    returns ``True`` (stop) once that text contains any of ``stop_words``.

    If ``stream_callback`` is set, newly decoded text is passed to it as it
    arrives. Text is held back while the generated text ends with a proper
    prefix of a stop word (e.g. ``"<hu"`` for ``"<human>"``) so stop words
    themselves are never streamed. When a stop word completes, held-back
    text that precedes it is still delivered to the callback.
    """

    def __init__(self, tokenizer, stop_words, stream_callback):
        """Create the criterion.

        Args:
            tokenizer: tokenizer whose ``decode`` turns a token id into text.
            stop_words: strings; generation stops once any one of them appears.
            stream_callback: callable receiving streamed ``str`` chunks, or a
                falsy value to disable streaming.
        """
        self._tokenizer = tokenizer
        self._stop_words = stop_words
        # All text decoded so far, exactly as decoded (never stripped).
        self._partial_result = ''
        # Decoded text not yet streamed because it may be the start of a stop word.
        self._stream_buffer = ''
        self._stream_callback = stream_callback

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Record the newest token; return ``True`` once a stop word has appeared."""
        first = not self._partial_result
        # Only the newest token of the first sequence in the batch is decoded.
        # NOTE(review): per-token decoding may produce replacement characters
        # when a multi-byte character spans several tokens — confirm for the
        # tokenizer in use.
        text = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += text

        stop_positions = [
            position
            for position in (self._partial_result.find(word) for word in self._stop_words)
            if position != -1
        ]
        if stop_positions:
            # Stop at the earliest stop word, but first stream whatever
            # legitimate text was still held back in front of it.
            self._flush_before_stop(min(stop_positions), text, first)
            return True

        if self._stream_callback:
            if first:
                # Drop leading whitespace at the very start of the response.
                text = text.lstrip()
            if self._ends_with_stop_word_prefix():
                # Could be the beginning of a stop word: hold it back for now.
                self._stream_buffer += text
                return False
            self._stream_callback(self._stream_buffer + text)
            self._stream_buffer = ''
        return False

    def _ends_with_stop_word_prefix(self):
        """Return whether the text so far ends with a proper prefix of any stop word."""
        return any(
            self._partial_result.endswith(stop_word[:length])
            for stop_word in self._stop_words
            for length in range(1, len(stop_word))
        )

    def _flush_before_stop(self, stop_at, text, first):
        """Stream the not-yet-streamed text that lies before index ``stop_at``.

        Args:
            stop_at: index in ``_partial_result`` where the stop word begins.
            text: the token text decoded in this call (unstripped).
            first: whether this call decoded the first token of the response.
        """
        if not self._stream_callback:
            return
        # The unstreamed region is the current token plus the buffer. If the
        # first token was lstripped before buffering, the stripped whitespace
        # sits before this offset, so it stays dropped as in normal streaming.
        pending_start = len(self._partial_result) - len(text) - len(self._stream_buffer)
        head = self._partial_result[max(pending_start, 0):stop_at]
        if first:
            head = head.lstrip()
        self._stream_buffer = ''
        if head:
            self._stream_callback(head)
class ChatModel:
    """Wrap a Hugging Face causal LM and its tokenizer for chat-style generation."""

    # Speaker tags used in the conversation prompt; ``human_id`` is also the
    # stop word that ends the bot's turn in ``do_inference``.
    human_id = "<human>"
    bot_id = "<bot>"

    def __init__(self, model_name, gpu_id, max_memory):
        """Load the model (float16) and its tokenizer.

        Args:
            model_name: Hugging Face model id or local path.
            gpu_id: CUDA device index used when ``max_memory`` is ``None``.
            max_memory: ``None`` to put the whole model on ``cuda:gpu_id``;
                otherwise a ``max_memory`` dict for accelerate's
                ``infer_auto_device_map`` (device -> size string such as
                ``"10GiB"``, optionally with a ``"cpu"`` entry).
        """
        # TODO: allow loading onto cpu
        if max_memory is None:
            # Recommended default for devices with > 40 GB VRAM: place every
            # module directly on the requested GPU. Loading with
            # device_map="auto" and then calling .to() fights accelerate's
            # dispatch hooks, and fails if any weights were offloaded.
            self._model = AutoModelForCausalLM.from_pretrained(
                model_name, torch_dtype=torch.float16, device_map={"": gpu_id})
        else:
            # Load under the given max_memory budget (for devices with
            # insufficient VRAM, or to shard across several GPUs).
            config = AutoConfig.from_pretrained(model_name)
            # Instantiate the architecture with empty weights so a device map
            # can be planned without allocating the real parameters.
            with init_empty_weights():
                model_from_conf = AutoModelForCausalLM.from_config(config)
            model_from_conf.tie_weights()
            # Create a device_map from max_memory. Each GPTNeoXLayer is kept on
            # a single device (assumes a GPT-NeoX-family model, e.g. Pythia).
            device_map = infer_auto_device_map(
                model_from_conf,
                max_memory=max_memory,
                no_split_module_classes=["GPTNeoXLayer"],
                dtype="float16"
            )
            # Load the real weights according to the planned device_map.
            self._model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=device_map,
                offload_folder="offload",  # optional offload-to-disk overflow directory (auto-created)
                offload_state_dict=True,
                torch_dtype=torch.float16
            )
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)

    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
        """Generate a continuation of ``prompt``.

        Generation ends after ``max_new_tokens`` tokens or once ``human_id``
        has been produced.

        Args:
            prompt: full prompt text.
            max_new_tokens: maximum number of tokens to generate.
            do_sample: sample (``True``) or decode greedily (``False``).
            temperature: sampling temperature.
            top_k: top-k sampling cutoff.
            stream_callback: optional callable receiving text chunks as they
                are generated.

        Returns:
            The decoded generated text only (prompt excluded, special tokens
            kept). It may still contain the trailing stop word.
        """
        stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)
        inputs = self._tokenizer(prompt, return_tensors='pt').to(self._model.device)
        outputs = self._model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            pad_token_id=self._tokenizer.eos_token_id,
            stopping_criteria=StoppingCriteriaList([stop_criteria]),
        )
        # Strip the prompt at the token level. Decoding the prompt tokens is
        # not guaranteed to reproduce `prompt` character-for-character, so
        # slicing the decoded text at len(prompt) could cut or leak characters.
        prompt_length = inputs['input_ids'].shape[1]
        return self._tokenizer.decode(outputs[0, prompt_length:])
class OpenChatKitShell(cmd.Cmd):
    """Interactive shell for chatting with a ``ChatModel``.

    Plain input lines are sent to the model as a human turn (the ``say``
    command); lines starting with ``/`` run shell commands such as ``/help``,
    ``/reset`` or ``/quit``.
    """
    intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
    prompt = ">>> "

    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
        """Store settings; the model itself is loaded lazily in ``preloop``."""
        super().__init__()
        self._gpu_id = gpu_id
        self._model_name_or_path = model_name_or_path
        self._max_tokens = max_tokens
        self._sample = sample
        self._temperature = temperature
        self._top_k = top_k
        self._retrieval = retrieval
        self._max_memory = max_memory
        self._do_stream = do_stream

    def preloop(self):
        """Load the model, the optional retrieval index, and a fresh conversation."""
        print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory)
        if self._retrieval:
            print("Loading retrieval index...")
            self._index = wp.WikipediaIndex()
        self._convo = convo.Conversation(
            self._model.human_id, self._model.bot_id)

    def precmd(self, line):
        """Translate raw input into a command line.

        ``/cmd args`` becomes ``cmd args``; any other non-blank text becomes
        ``say <text>``. ``cmd.Cmd`` reports end of input (Ctrl-D, closed
        stdin) as the line ``"EOF"``, which maps to ``quit`` so the shell
        exits instead of repeatedly sending "EOF" to the model. As a
        consequence, typing ``EOF`` literally also quits. Blank input maps
        to the empty command.
        """
        if line == 'EOF':
            print()  # end the prompt line before exiting
            return 'quit'
        if line.startswith('/'):
            return line[1:]
        if not line.strip():
            return ''
        return 'say ' + line

    def emptyline(self):
        """Ignore blank input; cmd.Cmd's default would repeat the last command."""
        return False

    def do_say(self, arg):
        """Send a message to the model and print its reply."""
        if self._retrieval:
            results = self._index.search(arg)
            if len(results) > 0:
                # Add the top retrieval hit as context before the human turn.
                self._convo.push_context_turn(results[0])

        self._convo.push_human_turn(arg)
        # Only pass a callback when streaming is enabled; without one the
        # stopping criterion skips its streaming bookkeeping entirely.
        if self._do_stream:
            stream_callback = lambda chunk: print(chunk, end='', flush=True)
        else:
            stream_callback = None
        output = self._model.do_inference(
            self._convo.get_raw_prompt(),
            self._max_tokens,
            self._sample,
            self._temperature,
            self._top_k,
            stream_callback,
        )
        self._convo.push_model_response(output)
        # A streamed reply is already on screen; just terminate the line.
        print("" if self._do_stream else self._convo.get_last_turn())

    def do_raw_say(self, arg):
        """Send the text verbatim as the entire prompt (no conversation) and print the output."""
        output = self._model.do_inference(
            arg,
            self._max_tokens,
            self._sample,
            self._temperature,
            self._top_k
        )
        print(output)

    def do_raw_prompt(self, arg):
        """Print the raw prompt built from the current conversation."""
        print(self._convo.get_raw_prompt())

    def do_reset(self, arg):
        """Discard the current conversation and start a new one."""
        self._convo = convo.Conversation(
            self._model.human_id, self._model.bot_id)

    def do_hyperparameters(self, arg):
        """Print the current generation hyperparameters."""
        print(
            f"Hyperparameters:\n"
            f" max_tokens: {self._max_tokens}\n"
            f" sample: {self._sample}\n"
            f" temperature: {self._temperature}\n"
            f" top_k: {self._top_k}"
        )

    def do_quit(self, arg):
        """Exit the shell."""
        return True
def _parse_max_memory(parser, gpu_vram, cpu_ram):
    """Build the ``max_memory`` dict for ``ChatModel`` from CLI values.

    Args:
        parser: the ``ArgumentParser``; used to report malformed values
            (``parser.error`` prints usage and exits).
        gpu_vram: list of ``"ID:GiB"`` strings from ``--gpu-vram``, or ``None``.
        cpu_ram: integer GiB from ``--cpu-ram``, or ``None``.

    Returns:
        ``None`` when ``--gpu-vram`` was not given; otherwise a dict mapping
        each CUDA id (int), and optionally ``'cpu'``, to a string like ``"10GiB"``.
    """
    if gpu_vram is None:
        if cpu_ram is not None:
            print("warning: --cpu-ram is ignored unless --gpu-vram is also given",
                  file=sys.stderr)
        return None

    max_memory = {}
    for spec in gpu_vram:
        gpu, sep, size = spec.partition(':')
        try:
            gpu_index = int(gpu)
        except ValueError:
            gpu_index = None
        if not sep or gpu_index is None or not size or ':' in size:
            parser.error(f"invalid --gpu-vram value {spec!r}; expected ID:GiB, e.g. 0:10")
        # CUDA ID as key, "<X>GiB" as value.
        max_memory[gpu_index] = f"{size}GiB"
    if cpu_ram is not None:
        max_memory['cpu'] = f"{cpu_ram}GiB"
    return max_memory


def main():
    """Parse command-line options and run the interactive OpenChatKit shell."""
    parser = argparse.ArgumentParser(
        description='test harness for OpenChatKit')
    parser.add_argument(
        '--gpu-id',
        default=0,
        type=int,
        help='the ID of the GPU to run on'
    )
    parser.add_argument(
        '--model',
        default=f"{INFERENCE_DIR}/../huggingface_models/Pythia-Chat-Base-7B",
        help='name/path of the model'
    )
    parser.add_argument(
        '--max-tokens',
        default=128,
        type=int,
        help='the maximum number of tokens to generate'
    )
    # '--sample' defaults to True, so on its own it can never turn sampling
    # off; '--no-sample' writes False to the same destination.
    parser.add_argument(
        '--sample',
        default=True,
        action='store_true',
        help='sample when generating (the default)'
    )
    parser.add_argument(
        '--no-sample',
        dest='sample',
        action='store_false',
        help='use greedy decoding instead of sampling'
    )
    parser.add_argument(
        '--no-stream',
        action='store_true',
        help='print each reply only when it is complete instead of streaming tokens'
    )
    parser.add_argument(
        '--temperature',
        default=0.6,
        type=float,
        help='temperature for the LM'
    )
    parser.add_argument(
        '--top-k',
        default=40,
        type=int,
        help='top-k for the LM'
    )
    parser.add_argument(
        '--retrieval',
        default=False,
        action='store_true',
        help='augment queries with context from the retrieval index'
    )
    parser.add_argument(
        '-g',
        '--gpu-vram',
        action='store',
        nargs='+',
        metavar='ID:GiB',
        required=False,
        help='max VRAM to allocate per GPU, as ID:GiB pairs (e.g. -g 0:10 1:12)',
    )
    parser.add_argument(
        '-r',
        '--cpu-ram',
        default=None,
        type=int,
        required=False,
        help='max CPU RAM to allocate, in GiB (only used together with --gpu-vram)',
    )
    args = parser.parse_args()

    max_memory = _parse_max_memory(parser, args.gpu_vram, args.cpu_ram)

    OpenChatKitShell(
        args.gpu_id,
        args.model,
        args.max_tokens,
        args.sample,
        args.temperature,
        args.top_k,
        args.retrieval,
        max_memory,
        not args.no_stream,
    ).cmdloop()
if __name__ == "__main__":
    # Launch the CLI only when executed as a script, never on import.
    main()