On peut simplement prendre une image PNG de référence pour venir coller les bons octets à la place de ceux-ci.
flag = open('flag.png', 'rb').read()
sample = open('sample.png', 'rb').read()
# On voit que le début est différent
print(flag[:12].hex()) # 00000000000000000000000d
print(sample[:12].hex()) # 89504e470d0a1a0a0000000d
with open('flag_corrected.png', 'wb') as f:
# On prend les 12 premiers octets d'un fichier valide
f.write(sample[:12])
# Puis le reste de notre fichier flag
f.write(flag[12:])