diff --git a/modules/lfs/content_store.go b/modules/lfs/content_store.go index 247191a1bf..788ef5b9a6 100644 --- a/modules/lfs/content_store.go +++ b/modules/lfs/content_store.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "hash" "io" "os" @@ -66,15 +67,20 @@ func (s *ContentStore) Get(meta *models.LFSMetaObject, fromByte int64) (io.ReadC // Put takes a Meta object and an io.Reader and writes the content to the store. func (s *ContentStore) Put(meta *models.LFSMetaObject, r io.Reader) error { - hash := sha256.New() - rd := io.TeeReader(r, hash) p := meta.RelativePath() - written, err := s.Save(p, rd) + + // Wrap the provided reader with an inline hashing and size checker + wrappedRd := newHashingReader(meta.Size, meta.Oid, r) + + // now pass the wrapped reader to Save - if there is a size mismatch or hash mismatch then + // the errors returned by the newHashingReader should percolate up to here + written, err := s.Save(p, wrappedRd) if err != nil { log.Error("Whilst putting LFS OID[%s]: Failed to copy to tmpPath: %s Error: %v", meta.Oid, p, err) return err } + // This shouldn't happen but it is sensible to test if written != meta.Size { if err := s.Delete(p); err != nil { log.Error("Cleaning the LFS OID[%s] failed: %v", meta.Oid, err) @@ -82,14 +88,6 @@ func (s *ContentStore) Put(meta *models.LFSMetaObject, r io.Reader) error { return errSizeMismatch } - shaStr := hex.EncodeToString(hash.Sum(nil)) - if shaStr != meta.Oid { - if err := s.Delete(p); err != nil { - log.Error("Cleaning the LFS OID[%s] failed: %v", meta.Oid, err) - } - return errHashMismatch - } - return nil } @@ -118,3 +116,45 @@ func (s *ContentStore) Verify(meta *models.LFSMetaObject) (bool, error) { return true, nil } + +type hashingReader struct { + internal io.Reader + currentSize int64 + expectedSize int64 + hash hash.Hash + expectedHash string +} + +func (r *hashingReader) Read(b []byte) (int, error) { + n, err := r.internal.Read(b) + + if n > 0 { + r.currentSize += int64(n) + wn, werr := r.hash.Write(b[:n]) + if wn != n || werr != nil { + return n, werr + } + } + + if err != nil && err == io.EOF { + if r.currentSize != r.expectedSize { + return n, errSizeMismatch + } + + shaStr := hex.EncodeToString(r.hash.Sum(nil)) + if shaStr != r.expectedHash { + return n, errHashMismatch + } + } + + return n, err +} + +func newHashingReader(expectedSize int64, expectedHash string, reader io.Reader) *hashingReader { + return &hashingReader{ + internal: reader, + expectedSize: expectedSize, + expectedHash: expectedHash, + hash: sha256.New(), + } +}