Fix kontext finetune issue when batch size >1 (#11921)
Some checks failed
Build documentation / build (push) Has been cancelled
Run dependency tests / check_dependencies (push) Has been cancelled
Run Flax dependency tests / check_flax_dependencies (push) Has been cancelled
Run Torch dependency tests / check_torch_dependencies (push) Has been cancelled
Fast GPU Tests on main / Setup Torch Pipelines CUDA Slow Tests Matrix (push) Has been cancelled
Fast GPU Tests on main / Torch CUDA Tests (lora) (push) Has been cancelled
Fast GPU Tests on main / Torch CUDA Tests (models) (push) Has been cancelled
Fast GPU Tests on main / Torch CUDA Tests (others) (push) Has been cancelled
Fast GPU Tests on main / Torch CUDA Tests (schedulers) (push) Has been cancelled
Fast GPU Tests on main / Torch CUDA Tests (single_file) (push) Has been cancelled
Fast GPU Tests on main / PyTorch Compile CUDA tests (push) Has been cancelled
Fast GPU Tests on main / PyTorch xformers CUDA tests (push) Has been cancelled
Fast GPU Tests on main / Examples PyTorch CUDA tests on Ubuntu (push) Has been cancelled
Fast tests on main / ${{ matrix.config.name }} (map[framework:pytorch image:diffusers/diffusers-pytorch-cpu name:Fast PyTorch CPU tests on Ubuntu report:torch_cpu runner:aws-general-8-plus]) (push) Has been cancelled
Fast tests on main / ${{ matrix.config.name }} (map[framework:pytorch_examples image:diffusers/diffusers-pytorch-cpu name:PyTorch Example CPU tests on Ubuntu report:torch_example_cpu runner:aws-general-8-plus]) (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Diffusers metadata / update_metadata (push) Has been cancelled
Fast GPU Tests on main / Torch Pipelines CUDA Tests (push) Has been cancelled

set drop_last to True

Signed-off-by: mymusise <mymusise1@gmail.com>
This commit is contained in:
Chengxi Guo
2025-07-19 07:38:58 +08:00
committed by GitHub
parent 5dc503aa28
commit cde02b061b

View File

@@ -1614,7 +1614,7 @@ def main(args):
)
if args.cond_image_column is not None:
logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,