mirror of
https://github.com/skindhu/Build-A-Large-Language-Model-CN.git
synced 2026-07-01 01:10:17 +08:00
@@ -1069,8 +1069,9 @@ def generate(model, idx, max_new_tokens, context_size,
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
if idx_next == eos_id: #E
|
||||
break
|
||||
idx_next = idx_next.unsqueeze(1)
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
return idx
|
||||
return idx
|
||||
|
||||
|
||||
#A For循环与之前相同:获取logits,仅关注最后的时间步
|
||||
|
||||
Reference in New Issue
Block a user