Finetuning Llama-7B/13B or MosaicML MPT-7B - Reproduce Vicuna / Alpaca

EDIT May 12: I am posting extra info in the thread to finetune MPT-7B.

EDIT May 23: thanks to @l-k-11235 we now have a step-by-step tutorial with a Gradio example.
Link in the thread.

EDIT June 2: LoRA layers can now be quantized; all Linear layers are quantizable in 4-bit. The 13B model finetunes smoothly.

Hello Community,

We can now finetune the 7B/13B llama models and reproduce Vicuna / Alpaca.
This is thanks to the new LoRA capability and the 4/8-bit loading (with bitsandbytes).

Remember, llama 7B is a decoder-only transformer with 32 layers, 32 heads, model dim 4096 and FFN dim 11008. This means that the self-attention modules (Q, K, V, O) take 4 x (4096 x 4096) x 32 = 2.1e9 parameters and the 3 position-wise feed-forward modules (w_1, w_2, w_3) take 3 x (4096 x 11008) x 32 = 4.3e9 parameters. The rest is negligible compared to those two key components.

In order to finetune llama 7B we will:

  • use LoRA for the self-attention modules, to massively reduce the number of trainable parameters.
  • use 4/8-bit loading for the w_X modules, to massively reduce the memory footprint of the model in GPU VRAM. Edit: the self-attention modules can now be quantized as well, on top of their LoRA adapters.

So the yaml config file will include this section:

#4/8bit
quant_layers: ['w_1', 'w_2', 'w_3', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
quant_type: "bnb_NF4"

#LoRa
lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear']
lora_rank: 8
lora_dropout: 0.05
lora_alpha: 16
lora_embedding: false

# Checkpointing
#use_ckpting: ['ffn', 'lora']

Gradient checkpointing is also available but brings little memory saving (it all depends on how close you are to your VRAM limit).

The llama model also uses a few specific features:

  • RMSNorm for layer normalization
  • SiLU activation
  • Rotary embeddings

Hence the model section of the yaml config file needs to be as follows:

# Model
model_task: lm
decoder_type: transformer_lm
layer_norm: rms
pos_ffn_activation_fn: 'silu'
max_relative_positions: -1
position_encoding: false
add_qkvbias: false
dec_layers: 32
heads: 32
hidden_size: 4096
word_vec_size: 4096
transformer_ff: 11008
dropout_steps: [0]
dropout: [0.0]
attention_dropout: [0.0]

Now we need datasets to finetune the model.
For those who have looked at the various implementations of Alpaca or Vicuna, you will have seen that they use JSON files containing the instructions / responses used for finetuning. One specificity of those JSON files is that a sentence can contain "\n" (newline), which is a sentence break in the OpenNMT world.

Hence we need to use ad-hoc formatted datasets. We flattened all the JSON files into plain text, and the "\n" have been replaced with a specific token ｟newline｠ (the same one we use for doc-level training).

We tweaked the tokenizer transform to magically replace ｟newline｠ => "\n" so that we could still use the legacy llama sentencepiece tokenizer model.
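
To make it concrete, an Alpaca-style JSON record such as

{"instruction": "Tell me about alpacas.", "input": "", "output": "Alpacas are members of the camelid family..."}

ends up as a single plain-text line of the form

Below is an instruction that describes a task. Write a response that appropriately completes the request.｟newline｠｟newline｠### Instruction:｟newline｠Tell me about alpacas.｟newline｠｟newline｠### Response:｟newline｠Alpacas are members of the camelid family...

(the record shown here is shortened for illustration; the actual files contain the full responses).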

The data section of the config then looks like this:

data:
    alpaca:
        path_src: "/dataAI/alpaca_clean.txt"
        transforms: [sentencepiece, filtertoolong]
        weight: 10
    sharegpt:
        path_src: "/dataAI/sharegpt.txt"
        transforms: [sentencepiece, filtertoolong]
        weight: 10

#### Subword
src_subword_model: "/dataAI/tokenizer.model"
tgt_subword_model: "/dataAI/tokenizer.model"

#### Filter
src_seq_length: 512
tgt_seq_length: 512

# silently ignore empty lines in the data
skip_empty_level: silent

# General opts
train_from: "/7B/llama7B-onmt.pt"
save_model: "/dataAI/llama7B-vicuna-onmt"

The two .txt files can be downloaded from here:
https://opennmt-models.s3.amazonaws.com/llama/alpaca_clean.txt
https://opennmt-models.s3.amazonaws.com/llama/sharegpt.txt
The tokenizer.model is the legacy one from Llama.
llama7B-onmt.pt is the result of the conversion using convert_llama.py (in tools).
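
For reference, once the converted checkpoint and the yaml sections above are assembled into a single config file (the file name below is just an example), the finetuning is launched with the regular training entry point:

onmt_train -config llama7B-vicuna-finetune.yaml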

Once the training is finished, you need to merge the llama7B-vicuna-onmt.pt file (the LoRA weights) into the original llama7B-onmt.pt model, using the lora_weights.py tool.
You can merge with two actions:
'merge' will create the same model as the original with modified weights, with no LoRA info and no optimizer state.
'concat' will add the LoRA info along with the optimizer state, in case you need to continue training.
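
Something along these lines (paths are examples; double-check the exact argument names with python3 tools/lora_weights.py --help):

python3 OpenNMT-py/tools/lora_weights.py \
    --action merge \
    --base_model /7B/llama7B-onmt.pt \
    --lora_weights /dataAI/llama7B-vicuna-onmt_step_4000.pt \
    --output /dataAI/llama7B-vicuna-merged-onmt.pt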

Results are very good even after a few hundred steps. Remember, instructions must be in the same format (including prompts) as in the training data.

Example 1:
Input, in a one-line input.txt file:

Below is an instruction that describes a task. Write a response that appropriately completes the request.｟newline｠｟newline｠### Instruction:｟newline｠Write a Python program that prints the first 10 Fibonacci numbers.｟newline｠｟newline｠### Response:｟newline｠

Output:

Here is a Python program that prints the first 10 Fibonacci numbers:

def fibonacci(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)

print(fibonacci(10))

This program takes a number n as an input and prints the first 10 Fibonacci numbers. The function fibonacci() returns the nth Fibonacci number. If n is 0 or 1, the function returns 0 and 1 respectively. Otherwise, it calls itself with n - 1 and n - 2, and returns the sum of the two results.

The program prints the first 10 Fibonacci numbers as follows:

0
1
1
2
3
5
8
13
21
34
55

Example 2:
Input

Below is an instruction that describes a task. Write a response that appropriately completes the request.｟newline｠｟newline｠### Instruction:｟newline｠Tell me about alpacas. ｟newline｠｟newline｠### Response:｟newline｠

Output:

Alpacas are members of the camelid family, which also includes llamas and camels. They are native to South America, where they have been domesticated for thousands of years. Alpacas are known for their soft, luxurious fiber, which is used to make clothing, blankets, and other textiles. They are also raised for their meat, which is similar to lamb. Alpacas are herd animals and are social animals, living in groups of 10-20 animals. They are generally calm and docile, and are known for their intelligence and curiosity. Alpacas are easy to care for and require little maintenance, making them a popular choice for small farms and hobbyists.

The command-line usage is a bit cumbersome (we need to add streaming with a Callback()).

python3 ~/nlp/OpenNMT-py/translate.py --config llama-inference.yaml -src input.txt -output output.txt && sed 's/｟newline｠/\n/g' output.txt
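
For completeness, the llama-inference.yaml referenced above can look roughly like this (a sketch only; paths are examples and the decoding settings mirror the MPT inference config further down in this thread):

transforms: [sentencepiece]

#### Subword
src_subword_model: "/dataAI/tokenizer.model"
tgt_subword_model: "/dataAI/tokenizer.model"

# Model info
model: "/dataAI/llama7B-vicuna-merged-onmt.pt"

# Inference
seed: 42
max_length: 512
gpu: 0
batch_type: sents
batch_size: 1
precision: fp16
random_sampling_topk: 40
random_sampling_topp: 0.75
random_sampling_temp: 0.1
beam_size: 1
report_time: true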

But with a nice Gradio interface, you can have the same fun as all those FastChat, Alpaca-LoRA, etc. implementations.


Where can I get the vocabulary for llama?

Here, sorry:
https://opennmt-models.s3.amazonaws.com/llama/vocab-llama.txt

and you need to set:
share_vocab: true
src_vocab: vocab-llama.txt
src_vocab_size: 32000
tgt_vocab_size: 32000

Thanks!
But now I get this error:

File "/opt/conda/lib/python3.10/site-packages/OpenNMT_py-3.1.1-py3.10.egg/onmt/model_builder.py", line 402, in build_base_model
    model.load_state_dict(checkpoint['model'],
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LanguageModel:
        Missing key(s) in state_dict: "decoder.layer_norm.bias". 

Are you on master?

Show me the log where it prints the model; the end should look like this:

  (31): TransformerLMDecoderLayer(
    (self_attn): MultiHeadedAttention(
      (linear_keys): Linear(
        in_features=4096, out_features=4096, bias=False
        (lora_dropout): Dropout(p=0.05, inplace=False)
      )
      (linear_values): Linear(
        in_features=4096, out_features=4096, bias=False
        (lora_dropout): Dropout(p=0.05, inplace=False)
      )
      (linear_query): Linear(
        in_features=4096, out_features=4096, bias=False
        (lora_dropout): Dropout(p=0.05, inplace=False)
      )
      (softmax): Softmax(dim=-1)
      (dropout): Dropout(p=0.0, inplace=False)
      (final_linear): Linear(
        in_features=4096, out_features=4096, bias=False
        (lora_dropout): Dropout(p=0.05, inplace=False)
      )
    )
    (feed_forward): PositionwiseFeedForward(
      (w_1): Linear8bitLt(in_features=4096, out_features=11008, bias=True)
      (w_2): Linear8bitLt(in_features=11008, out_features=4096, bias=True)
      (layer_norm): RMSNorm()
      (dropout_1): Dropout(p=0.0, inplace=False)
      (dropout_2): Dropout(p=0.0, inplace=False)
      (w_3): Linear8bitLt(in_features=4096, out_features=11008, bias=True)
    )
    (layer_norm_1): RMSNorm()
    (drop): Dropout(p=0.0, inplace=False)
  )
)

)
(generator): Linear(in_features=4096, out_features=32000, bias=True)
)
[2023-05-07 21:12:40,445 INFO] encoder: 0
[2023-05-07 21:12:40,445 INFO] decoder: 6616571904
[2023-05-07 21:12:40,445 INFO] * number of parameters: 6616571904
[2023-05-07 21:12:40,445 INFO] * src vocab size = 32000
[2023-05-07 21:12:40,445 INFO] * tgt vocab size = 32000

Yes, I am on master.
update_vocab: true helped with the error, but I am testing on a 24GB card and got OOM. As I understand it, LoRA 8-bit quantization is not ready yet? I will try on a 48GB card soon.

It is ready and working on 24GB; that's what I am doing.
You can see in my log that
w_1/w_2 are quantized and
final_linear is LoRA-replaced.

Maybe reduce your batch size / seq_len.

Post your config here.

Yes, I see. In my logs no quantization is done…

[2023-05-08 08:47:24,799 INFO] LanguageModel(
  (decoder): TransformerLMDecoder(
    (embeddings): Embeddings(
      (make_embedding): Sequential(
        (emb_luts): Elementwise(
          (0): Embedding(32000, 4096, padding_idx=3)
        )
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (layer_norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=True)
    (transformer_layers): ModuleList(
      (0-31): 32 x TransformerLMDecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linear_keys): Linear(in_features=4096, out_features=4096, bias=False)
          (linear_values): Linear(in_features=4096, out_features=4096, bias=False)
          (linear_query): Linear(in_features=4096, out_features=4096, bias=False)
          (softmax): Softmax(dim=-1)
          (dropout): Dropout(p=0.0, inplace=False)
          (final_linear): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=4096, out_features=11008, bias=True)
          (w_2): Linear(in_features=11008, out_features=4096, bias=True)
          (layer_norm): RMSNorm()
          (dropout_1): Dropout(p=0.0, inplace=False)
          (dropout_2): Dropout(p=0.0, inplace=False)
          (w_3): Linear(in_features=4096, out_features=11008, bias=True)
        )
        (layer_norm_1): RMSNorm()
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (generator): Linear(in_features=4096, out_features=32000, bias=True)
)
[2023-05-08 08:47:24,802 INFO] encoder: 0
[2023-05-08 08:47:24,802 INFO] decoder: 6608183296
[2023-05-08 08:47:24,803 INFO] * number of parameters: 6608183296
[2023-05-08 08:47:24,803 INFO]  * src vocab size = 32000
[2023-05-08 08:47:24,803 INFO]  * tgt vocab size = 32000

With this config I get Missing key(s) in state_dict: "decoder.layer_norm.bias"

#LoRa
lora_layers: ['linear_values', 'linear_query']
lora_rank: 8
lora_dropout: 0.05
lora_alpha: 16
lora_embedding: false

#8bit
quant_layers: ['w_1', 'w_2', 'w_3']

data:
  corpus:
    path_src: "../dataset/llama.train.txt"
    transforms: [sentencepiece, filtertoolong]
    weight: 10
  valid:
    path_src: "../dataset/llama.test.txt"
    transforms: [sentencepiece, filtertoolong]

#### Subword
src_subword_model: "llama-tokenizer.model"
tgt_subword_model: "llama-tokenizer.model"
share_vocab: true
src_vocab: vocab-llama.txt
#update_vocab: true
src_vocab_size: 32000
tgt_vocab_size: 32000

#### Filter
src_seq_length: 512
tgt_seq_length: 512

# silently ignore empty lines in the data
skip_empty_level: silent

# General opts
train_from: "llama7b.pt"
save_model: "ready/llama7b-main.pt"

#reset_optim: all
keep_checkpoint: 1
save_checkpoint_steps: 100
#average_decay: 0.0005
seed: 1234

report_every: 10

world_size: 1
gpu_ranks: [0]
optim: "fusedadam"
learning_rate: 0.2
warmup_steps: 100

# Model
model_task: lm
decoder_type: transformer_lm
layer_norm: rms
pos_ffn_activation_fn: 'silu'
max_relative_positions: -1
position_encoding: false
add_qkvbias: False
dec_layers: 32
heads: 32
hidden_size: 4096
word_vec_size: 4096
transformer_ff: 11008
dropout_steps: [0]
dropout: [0.0]
attention_dropout: [0.0]

If I add reset_optim: all and update_vocab: true, I get OOM and it looks like there is no LoRA and no quantization in the model.
I am on the master branch. onmt_train --help shows it accepts the LoRA parameters.

Add this:
override_opts: true

but do not set reset_optim / update_vocab.

Thanks. Now the training starts, but I can't get past the "Grad overflow on iteration" error. Could you share a full working config?

I am currently training open_llama with this config:

# Corpus opts:
data:
    alpaca:
        path_src: "alpaca_clean.txt"
        transforms: [sentencepiece, filtertoolong]
        weight: 10
    sharegpt:
        path_src: "sharegpt.txt"
        transforms: [sentencepiece, filtertoolong]
        weight: 10

    valid:
        path_src: "valid.txt"
        transforms: [sentencepiece]

### Transform related opts:
#### Subword
src_subword_model: "openllama/tokenizer.model"
tgt_subword_model: "openllama/tokenizer.model"

#### Filter
src_seq_length: 512
tgt_seq_length: 512

#truncated_decoder: 32

# silently ignore empty lines in the data
skip_empty_level: silent

# General opts
train_from: "openllama/openllama-onmt.pt"
save_model: "openllama/open-vicuna-diff-onmt"
keep_checkpoint: 10
save_checkpoint_steps: 400
seed: 1234
report_every: 10
train_steps: 4000
valid_steps: 400

# Batching
bucket_size: 32768
#bucket_size: 1
num_workers: 2
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"
batch_size: 896
valid_batch_size: 256
batch_size_multiple: 1
accum_count: [32]
accum_steps: [0]

override_opts: true 

share_vocab: true
save_data: "/openllama"
src_vocab: "openllama/openllama-vocab.txt"
src_vocab_size: 32000
tgt_vocab_size: 32000

decoder_start_token: '<s>'
# Optimization
model_dtype: "fp16"
apex_opt_level: ""
optim: "fusedadam"
learning_rate: 0.00002
warmup_steps: 100
decay_method: "none"
#learning_rate_decay: 0.98
#start_decay_steps: 100
#decay_steps: 10
adam_beta2: 0.998
max_grad_norm: 0
label_smoothing: 0.0
param_init: 0
param_init_glorot: true
normalization: "tokens"

#LoRa
lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear']
lora_rank: 8
lora_dropout: 0.05
lora_alpha: 16
lora_embedding: false

#8bit
quant_layers: ['w_1', 'w_2', 'w_3']
# Model
model_task: lm
decoder_type: transformer_lm
layer_norm: rms
pos_ffn_activation_fn: 'silu'
max_relative_positions: -1
position_encoding: false
add_qkvbias: false
dec_layers: 32
heads: 32
hidden_size: 4096
word_vec_size: 4096
transformer_ff: 11008
dropout_steps: [0]
dropout: [0.0]
attention_dropout: [0.0]

Thanks! With roughly the same config it works well. I just changed the lr to 0.0002. But after the first save at step 200 and the evaluation I got these errors:

[2023-05-08 12:13:07,546 INFO] Train perplexity: 3.67446
[2023-05-08 12:13:07,546 INFO] Train accuracy: 67.1748
[2023-05-08 12:13:07,546 INFO] Sentences processed: 34564
[2023-05-08 12:13:07,546 INFO] Average bsz:  767/ 767/ 5
[2023-05-08 12:13:07,546 INFO] Validation perplexity: 2.86707
[2023-05-08 12:13:07,546 INFO] Validation accuracy: 71.36
[2023-05-08 12:13:07,689 INFO] Saving checkpoint ready/llama7b-main.pt_step_200.pt
[2023-05-08 12:13:09,368 INFO] Step 201, cuda OOM - batch removed
[2023-05-08 12:13:09,484 INFO] Step 201, cuda OOM - batch removed
[2023-05-08 12:13:09,512 INFO] Step 201, cuda OOM - batch removed
.....
TypeError: multi_tensor_l2norm(): incompatible function arguments. The following argument types are supported:
    1. (arg0: int, arg1: torch.Tensor, arg2: List[List[torch.Tensor]], arg3: Optional[bool]) -> Tuple[torch.Tensor, torch.Tensor]

Invoked with: 65536, tensor([0], device='cuda:0', dtype=torch.int32), [[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]], True

And I can't restart from the saved checkpoint. If I change train_from to the saved checkpoint I get:

Traceback (most recent call last):
  File "/opt/conda/bin/onmt_train", line 33, in <module>
    sys.exit(load_entry_point('OpenNMT-py==3.1.1', 'console_scripts', 'onmt_train')())
  File "/opt/conda/lib/python3.10/site-packages/OpenNMT_py-3.1.1-py3.10.egg/onmt/bin/train.py", line 65, in main
    train(opt)
  File "/opt/conda/lib/python3.10/site-packages/OpenNMT_py-3.1.1-py3.10.egg/onmt/bin/train.py", line 50, in train
    train_process(opt, device_id=0)
  File "/opt/conda/lib/python3.10/site-packages/OpenNMT_py-3.1.1-py3.10.egg/onmt/train_single.py", line 164, in main
    model = build_model(model_opt, opt, vocabs, checkpoint)
  File "/opt/conda/lib/python3.10/site-packages/OpenNMT_py-3.1.1-py3.10.egg/onmt/model_builder.py", line 414, in build_model
    model = build_base_model(model_opt, vocabs, checkpoint)
  File "/opt/conda/lib/python3.10/site-packages/OpenNMT_py-3.1.1-py3.10.egg/onmt/model_builder.py", line 385, in build_base_model
    if '0.weight' in checkpoint['generator']:
TypeError: argument of type 'NoneType' is not iterable

The first error results from too many OOMs.
If you use a dataset that is not the same as mine, maybe try to reduce the batch size a bit, or check whether some other processes are using your GPU.

Anyway, you cannot train_from a LoRA checkpoint directly.
Either you start again (recommended, because otherwise it will restart from the beginning of the dataset),
or
you would need to use the lora_weights tool with --action concat
and train from the resulting merged checkpoint.
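
Something like this, using the paths from your config (argument names other than --action may differ slightly, check the tool's --help), and then point train_from at the resulting file:

python3 OpenNMT-py/tools/lora_weights.py \
    --action concat \
    --base_model llama7b.pt \
    --lora_weights ready/llama7b-main.pt_step_200.pt \
    --output ready/llama7b-main-concat_step_200.pt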


Thanks. It looks like saving a checkpoint uses some additional memory that is not freed afterwards. But anyway, nice work!

How can I convert the finetuned model to CTranslate2?

It won't work for now; it requires some changes in the OpenNMT-py => CT2 converter.
You'll have to be patient. Inference works fine in -py.


MosaicML released another 7B model that is more permissive in terms of usage.
Here is the blog page: Introducing MPT-7B: A New Standard for Open-Source, Commercially Usable LLMs

The architecture is slightly different from llama, but we added the required features in OpenNMT-py.

The first step is to convert the Hugging Face checkpoint format into OpenNMT-py format.

Use the converter tools/convert_mpt.py from the repo.

Then download the bpe model and the vocab file here:
https://opennmt-models.s3.amazonaws.com/mosaic-MPT/mpt-model.bpe
https://opennmt-models.s3.amazonaws.com/mosaic-MPT/mpt.vocab
Those were created from the tokenizer.json file on the Hugging Face repo.

You will also need to get the Alpaca and ShareGPT data files from the first post in this thread (same files as for the llama finetuning).

Then you can use this config file to run the finetuning with LoRA + 8-bit loading:

# Corpus opts:
data:
    alpaca:
        path_src: "alpaca_clean.txt"
        transforms: [onmt_tokenize, filtertoolong]
        weight: 10
    sharegpt:
        path_src: "sharegpt.txt"
        transforms: [onmt_tokenize, filtertoolong]
        weight: 10

    valid:
        path_src: "valid.txt"
        transforms: [onmt_tokenize]

### Transform related opts:
#### Subword
src_subword_type: bpe
src_subword_model: "mpt-model.bpe"
src_onmttok_kwargs: '{"mode": "conservative"}'

tgt_subword_type: bpe
tgt_subword_model: "mpt-model.bpe"
tgt_onmttok_kwargs: '{"mode": "conservative"}'
gpt2_pretok: true

#### Filter
src_seq_length: 512
tgt_seq_length: 512

# silently ignore empty lines in the data
skip_empty_level: silent

# General opts
train_from: "mpt7b-onmt.pt"
save_model: "/mpt7B/mpt7B-vicuna-onmt"
keep_checkpoint: 10
save_checkpoint_steps: 400
seed: 1234
report_every: 10
train_steps: 4000
valid_steps: 400

# Batching
bucket_size: 32768
#bucket_size: 1
num_workers: 2
world_size: 1
gpu_ranks: [0]
batch_type: "tokens"
batch_size: 896
valid_batch_size: 256
batch_size_multiple: 1
accum_count: [32]
accum_steps: [0]

override_opts: true  # CAREFUL: this requires all settings to be defined below

share_vocab: true
save_data: "/dataAI"
src_vocab: "mpt.vocab"
src_vocab_size: 50432
tgt_vocab_size: 50432
default_specials: ['</s>', '<blank>']

decoder_start_token: '</s>'
# Optimization
model_dtype: "fp16"
apex_opt_level: ""
optim: "fusedadam"
learning_rate: 0.0002
warmup_steps: 100
decay_method: "none"
#learning_rate_decay: 0.98
#start_decay_steps: 100
#decay_steps: 10
adam_beta2: 0.998
max_grad_norm: 0
label_smoothing: 0.0
param_init: 0
param_init_glorot: true
normalization: "tokens"

#LoRa
lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear']
lora_rank: 8
lora_dropout: 0.05
lora_alpha: 16
lora_embedding: false

#8bit
quant_layers: ['w_1', 'w_2']
# Model
model_task: lm
decoder_type: transformer_lm
layer_norm: standard
pos_ffn_activation_fn: 'gelu'
max_relative_positions: -2
position_encoding: false
add_qkvbias: false
dec_layers: 32
heads: 32
hidden_size: 4096
word_vec_size: 4096
transformer_ff: 16384
dropout_steps: [0]
dropout: [0.0]
attention_dropout: [0.0]

It took 8-9 hours on my RTX 4090.

Then you need to merge the LoRA weights into the base model (same as for llama).

Inference is also similar; just make sure you set up onmt_tokenize and the BPE model instead of sentencepiece.

transforms: [onmt_tokenize]

#### Subword
src_subword_type: bpe
src_subword_model: "mpt-model.bpe"
src_onmttok_kwargs: '{"mode": "conservative"}'

tgt_subword_type: bpe
tgt_subword_model: "mpt-model.bpe"
tgt_onmttok_kwargs: '{"mode": "conservative"}'
gpt2_pretok: true
# Model info
model: "mpt7B/mpt7B-vicuna-merged-onmt_step_4000.pt"

# Inference
seed: 42
max_length: 512
gpu: 0
batch_type: sents
batch_size: 1
precision: fp16
random_sampling_topk: 40
random_sampling_topp: 0.75
random_sampling_temp: 0.1
beam_size: 1
report_time: true

The output is very similar to the finetuned llama.

DISCLAIMER:
While MPT-7B is more permissive (commercial usage allowed), it is unclear whether the Alpaca / ShareGPT datasets are allowed for commercial usage. Alpaca seems to have been generated through the OpenAI API, which restricts downstream usage. ShareGPT seems to be ChatGPT web output, for which the TOS is different.

The best would be to finetune using the OpenAssistant dataset.


Step-by-step tutorial for "Vicuna" replication.
