
How can the model be trained on multiple GPUs with DeepSpeed? #22

Open
RayneSun opened this issue Jul 12, 2023 · 5 comments

Comments

@RayneSun

As the title says.

@shuxueslpi
Owner

There are still a few issues with that at the moment; I'm debugging it myself and will update the repo as soon as possible.

@yyqi17
yyqi17 commented Jul 26, 2023

Below is the main replacement code that got DeepSpeed single-node multi-GPU training working for me (it replaces trainer = LoRATrainer and everything after it):

import deepspeed

# Hand the model, its trainable parameters, the dataset and the collate
# function to DeepSpeed; it builds the engine, optimizer and dataloader.
model_engine, optimizer, train_dataloader, _ = deepspeed.initialize(config=conf,
                                                                    model=model,
                                                                    model_parameters=model.parameters(),
                                                                    training_data=train_dataset,
                                                                    collate_fn=coll_fn)
model_engine.train()
for i_epoch in range(global_args.num_train_epochs):
    for micro_step, batch in enumerate(train_dataloader):
        # Each process moves its micro-batch onto its own GPU.
        input_ids = batch["input_ids"].to(model_engine.local_rank)
        labels = batch["labels"].to(model_engine.local_rank)

        outputs = model_engine.forward(input_ids=input_ids, labels=labels)
        loss = outputs[0]

        # DeepSpeed handles gradient accumulation and the optimizer step.
        model_engine.backward(loss)
        model_engine.step()

    # Save the LoRA weights once per epoch.
    save_dir = f'{global_args.output_dir}/{i_epoch}'
    model_engine.save_pretrained(save_dir)

Additional notes:

  1. Using the original DataCollatorForChatGLM directly as coll_fn causes problems here; coll_fn needs to be a standalone function (roughly the body of DataCollatorForChatGLM.__call__) — see the sketch after this list.
  2. The model is loaded the same way as in the official script.

Finally, change python to the deepspeed launcher in train.sh and it runs.
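For reference, here is a minimal sketch of such a standalone collate function. It assumes each dataset example already carries input_ids and labels as Python lists; the padding values, max_len and field names are assumptions, not the repo's actual implementation.

import torch

# Hypothetical standalone collate function, following the idea of
# DataCollatorForChatGLM.__call__; field names and defaults are assumptions.
def coll_fn(batch, pad_token_id=0, ignore_index=-100, max_len=1024):
    # Pad every example to the length of the longest one in the batch.
    longest = min(max(len(d["input_ids"]) for d in batch), max_len)
    input_ids_list, labels_list = [], []
    for d in batch:
        ids = d["input_ids"][:longest]
        lbs = d["labels"][:longest]
        pad_len = longest - len(ids)
        input_ids_list.append(ids + [pad_token_id] * pad_len)
        labels_list.append(lbs + [ignore_index] * pad_len)
    return {
        "input_ids": torch.tensor(input_ids_list, dtype=torch.long),
        "labels": torch.tensor(labels_list, dtype=torch.long),
    }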

@yyqi17

yyqi17 commented Oct 13, 2023


Is this conf the lora_config?

No, conf is the DeepSpeed configuration, for example something like this:

conf = {"train_micro_batch_size_per_gpu": args.per_device_train_batch_size,
      "gradient_accumulation_steps": args.gradient_accumulation_steps,
      "gradient_clipping": 1.0,
      "optimizer": {
          "type": "Adam",
          "params": {
              "lr": args.learning_rate,
              "betas": [
                  0.9,
                  0.95
              ],
              "eps": 1e-8,
              "weight_decay": args.weight_decay
          }
      },
      "fp16": {
          "enabled": False
      },
      "zero_optimization": {
          "stage": args.zero_stage,
          "offload_optimizer": {
              "device": "cpu",
              "pin_memory": True
          },
          "allgather_partitions": True,
          "allgather_bucket_size": 2e8,
          "overlap_comm": True,
          "reduce_scatter": True,
          "reduce_bucket_size": 2e8,
          "contiguous_gradients": True
      },
  }
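As a side note, the args fields referenced above would come from the script's own argument parser. A minimal sketch, with argument names taken from the config snippet (the deepspeed launcher additionally passes a --local_rank argument to every worker process, so it helps to accept it):

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    # Hyperparameters referenced by the DeepSpeed config above.
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--zero_stage", type=int, default=2)
    # Injected by the deepspeed launcher for each process.
    parser.add_argument("--local_rank", type=int, default=-1)
    return parser.parse_args()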

@WellWang-S


Multi-GPU training fails for me with RuntimeError: Expected all tensors to be on the same device, but found at least two devices. Did you run into this?

@yyqi17

yyqi17 commented Nov 9, 2023


When I ran into that error, it came from the model-loading step, i.e. the model = xxxModel() call before the code above. You could check whether the model's device_map is set correctly.
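To make that concrete, here is a minimal sketch of pinning the whole model to the current process's GPU (the checkpoint name and flags are illustrative, not the repo's exact loading code). When DeepSpeed provides the data parallelism, each rank should hold a full copy of the model rather than letting device_map="auto" spread layers across GPUs.

import os
from transformers import AutoModel

# Each DeepSpeed worker gets one complete copy of the model on its own GPU.
# device_map="auto" would shard layers across devices and can trigger the
# "Expected all tensors to be on the same device" error.
local_rank = int(os.getenv("LOCAL_RANK", "0"))  # set by the deepspeed launcher
model = AutoModel.from_pretrained(
    "THUDM/chatglm2-6b",            # assumed checkpoint; use the repo's actual model path
    trust_remote_code=True,
    device_map={"": local_rank},    # pin every module to this rank's device
)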
