I have followed this guide as closely as possible: https://github.com/kingoflolz/mesh-transformer-jax
I'm trying to fine-tune GPT-J with a small dataset of ~500 lines:
You are important to me. <|endoftext|>I love spending time with you. <|endoftext|>You make me smile. <|endoftext|>I feel so lucky to be your friend. <|endoftext|>You can always talk to me, even if it’s about something that makes you nervous or scared or sad. <|endoftext|>etc...

Running the create_finetune_tfrecords.py script (from the repo mentioned above) on this data produces a .tfrecords file with "2" in its name, which I understand means my data was packed into 2 sequences.
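In case it helps, this is roughly how I sanity-checked that number. It's my own sketch, assuming the script simply tokenizes the raw text with the GPT-2 BPE (which GPT-J also uses) and packs the tokens into 2048-token chunks; "dataset.txt" is just a placeholder for my input file:

    from transformers import GPT2TokenizerFast

    # GPT-J shares the GPT-2 BPE vocabulary, so this should give a close token count.
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    with open("dataset.txt", encoding="utf-8") as f:  # placeholder: my ~500-line file
        text = f.read()

    tokens = tokenizer(text)["input_ids"]
    seq_len = 2048
    n_sequences = -(-len(tokens) // seq_len)  # ceiling division
    print(f"{len(tokens)} tokens -> {n_sequences} sequences of {seq_len}")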
I could really use some advice with the .json config file. What hyperparameters do you recommend for this small dataset?
This is the best I came up with while following the guide:
{ "layers": 28, "d_model": 4096, "n_heads": 16, "n_vocab": 50400, "norm": "layernorm", "pe": "rotary", "pe_rotary_dims": 64, "seq": 2048, "cores_per_replica": 8, "per_replica_batch": 1, "gradient_accumulation_steps": 2, "warmup_steps": 1, "anneal_steps": 9, "lr": 1.2e-4, "end_lr": 1.2e-5, "weight_decay": 0.1, "total_steps": 10, "tpu_size": 8, "bucket": "chat-app-tpu-bucket-europe", "model_dir": "finetune_dir", "train_set": "james_bond_1.train.index", "val_set": {}, "eval_harness_tasks": [ ], "val_batches": 2, "val_every": 400000, "ckpt_every": 1, "keep_every": 1, "name": "GPT3_6B_pile_rotary", "wandb_project": "mesh-transformer-jax", "comment": ""}The problem is that, when I test the fine-tuned model, I get responses that make no sense:

