Update 5.在无标记数据集上进行预训练.md

This commit is contained in:
yuhui
2025-05-09 19:59:55 +08:00
committed by GitHub
parent 5ed44ac984
commit f399a51b79

View File

@@ -1069,8 +1069,9 @@ def generate(model, idx, max_new_tokens, context_size,
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
if idx_next == eos_id: #E
break
idx_next = idx_next.unsqueeze(1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
return idx
#A For循环与之前相同:获取logits,仅关注最后的时间步