Skip to content
Snippets Groups Projects
sql.go 103 KiB
Newer Older
  • Learn to ignore specific revisions
  • // Next prepares the next result row for reading with the [Rows.Scan] method. It
    
    // returns true on success, or false if there is no next result row or an error
    
    // happened while preparing it. [Rows.Err] should be consulted to distinguish between
    // the two cases.
    
    // Every call to [Rows.Scan], even the first one, must be preceded by a call to [Rows.Next].
    
    func (rs *Rows) Next() bool {
    	// A call to Next means the caller no longer needs the previous row's
    	// Scan results (RawBytes memory included), so drop the read lock that a
    	// prior Scan may still hold; while held it keeps awaitDone from calling
    	// close.
    	rs.closemuRUnlockIfHeldByScan()

    	// A recorded context error ends iteration right away.
    	if errp := rs.contextDone.Load(); errp != nil {
    		return false
    	}

    	// Advance while holding the closemu read lock. The deferred unlock also
    	// runs if nextLocked panics.
    	shouldClose, advanced := func() (bool, bool) {
    		rs.closemu.RLock()
    		defer rs.closemu.RUnlock()
    		return rs.nextLocked()
    	}()
    	if !shouldClose {
    		return advanced
    	}

    	// Close is called only after the read lock above has been released.
    	rs.Close()
    	if !advanced {
    		rs.hitEOF = true
    	}
    	return advanced
    }
    
    func (rs *Rows) nextLocked() (doClose, ok bool) {
    	if rs.closed {
    		return false, false
    
    
    	// Lock the driver connection before calling the driver interface
    	// rowsi to prevent a Tx from rolling back the connection at the same time.
    	rs.dc.Lock()
    	defer rs.dc.Unlock()
    
    
    	if rs.lastcols == nil {
    
    		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
    
    	rs.lasterr = rs.rowsi.Next(rs.lastcols)
    
    	if rs.lasterr != nil {
    		// Close the connection if there is a driver error.
    		if rs.lasterr != io.EOF {
    
    		}
    		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
    		if !ok {
    
    		}
    		// The driver is at the end of the current result set.
    		// Test to see if there is another result set after the current one.
    
    shawnps's avatar
    shawnps committed
    		// Only close Rows if there are no further result sets to read.
    
    		if !nextResultSet.HasNextResultSet() {
    
    // NextResultSet prepares the next result set for reading. It reports whether
    
    // there are further result sets, or false if there is no further result set
    
    // or if there is an error advancing to it. The [Rows.Err] method should be consulted
    
    // to distinguish between the two cases.
    //
    
    // After calling NextResultSet, the [Rows.Next] method should always be called before
    
    // scanning. If there are further result sets they may not have rows in the result
    // set.
    func (rs *Rows) NextResultSet() bool {
    
    	// If the user's calling NextResultSet, they're done with their previous
    	// row's Scan results (any RawBytes memory), so we can release the read lock
    	// that would be preventing awaitDone from calling close.
    	rs.closemuRUnlockIfHeldByScan()
    
    
    	var doClose bool
    	defer func() {
    		if doClose {
    			rs.Close()
    		}
    	}()
    	rs.closemu.RLock()
    	defer rs.closemu.RUnlock()
    
    	if rs.closed {
    
    	rs.lastcols = nil
    	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
    	if !ok {
    
    
    	// Lock the driver connection before calling the driver interface
    	// rowsi to prevent a Tx from rolling back the connection at the same time.
    	rs.dc.Lock()
    	defer rs.dc.Unlock()
    
    
    	rs.lasterr = nextResultSet.NextResultSet()
    
    // Err returns the error, if any, that was encountered during iteration.
    
    // Err may be called after an explicit or implicit [Rows.Close].
    
    	// Return any context error that might've happened during row iteration,
    	// but only if we haven't reported the final Next() = false after rows
    	// are done, in which case the user might've canceled their own context
    	// before calling Rows.Err.
    	if !rs.hitEOF {
    		if errp := rs.contextDone.Load(); errp != nil {
    			return *errp
    		}
    
    	rs.closemu.RLock()
    	defer rs.closemu.RUnlock()
    
    // rawbuf hands back the buffer that RawBytes values are appended to.
    // The same buffer is reused by successive calls to Rows.Scan.
    //
    // Usage:
    //
    //	rawBytes = rows.setrawbuf(append(rows.rawbuf(), value...))
    func (rs *Rows) rawbuf() []byte {
    	// convertAssignRows may pass a nil *Rows; tolerating it here keeps the
    	// callers free of a special case.
    	if rs != nil {
    		return rs.raw
    	}
    	return nil
    }
    
    // setrawbuf installs b, the result of appending a new value to the RawBytes
    // buffer, as the current buffer. It returns the newly appended value.
    func (rs *Rows) setrawbuf(b []byte) RawBytes {
    	if rs != nil {
    		// Everything past the previous buffer length is the new value.
    		prevLen := len(rs.raw)
    		rs.raw = b
    		return RawBytes(b[prevLen:])
    	}
    	// convertAssignRows may pass a nil *Rows; there is no buffer to
    	// maintain in that case, so b is handed back as is.
    	return RawBytes(b)
    }
    
    
    // Errors reported by Rows methods when the rows cannot be used.
    var (
    	errRowsClosed = errors.New("sql: Rows are closed")
    	errNoRows     = errors.New("sql: no Rows available")
    )
    
    
    Brad Fitzpatrick's avatar
    Brad Fitzpatrick committed
    // Columns returns the column names.
    
    // Columns returns an error if the rows are closed.
    
    Brad Fitzpatrick's avatar
    Brad Fitzpatrick committed
    func (rs *Rows) Columns() ([]string, error) {
    
    	rs.closemu.RLock()
    	defer rs.closemu.RUnlock()
    	if rs.closed {
    
    		return nil, rs.lasterrOrErrLocked(errRowsClosed)
    
    Brad Fitzpatrick's avatar
    Brad Fitzpatrick committed
    	}
    	if rs.rowsi == nil {
    
    		return nil, rs.lasterrOrErrLocked(errNoRows)
    
    Brad Fitzpatrick's avatar
    Brad Fitzpatrick committed
    	return rs.rowsi.Columns(), nil
    }
    
    
    // ColumnTypes returns column information such as column type, length,
    // and nullable. Some information may not be available from some drivers.
    func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
    
    	rs.closemu.RLock()
    	defer rs.closemu.RUnlock()
    	if rs.closed {
    
    		return nil, rs.lasterrOrErrLocked(errRowsClosed)
    
    		return nil, rs.lasterrOrErrLocked(errNoRows)
    
    	rs.dc.Lock()
    	defer rs.dc.Unlock()
    
    	return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
    
    }
    
    // ColumnType contains the name and type of a column.
    //
    // Values are filled in by rowsColumnInfoSetupConnLocked from the optional
    // driver.RowsColumnType* interfaces. Each has* flag stores the ok result the
    // driver returned together with the corresponding value.
    type ColumnType struct {
    	name string // taken from the driver's Columns result; returned by Name

    	hasNullable       bool // ok result of ColumnTypeNullable
    	hasLength         bool // ok result of ColumnTypeLength
    	hasPrecisionScale bool // ok result of ColumnTypePrecisionScale

    	nullable     bool         // returned by Nullable
    	length       int64        // returned by Length
    	databaseType string       // returned by DatabaseTypeName; empty if the driver does not provide it
    	precision    int64        // returned by DecimalSize
    	scale        int64        // returned by DecimalSize
    	scanType     reflect.Type // returned by ScanType; the type of any if the driver does not provide it
    }
    
    // Name reports the column's name or alias.
    func (ci *ColumnType) Name() string {
    	columnName := ci.name
    	return columnName
    }
    
    // Length reports the column type length for variable length column types
    // such as text and binary field types. An unbounded type length is reported
    // as [math.MaxInt64] (any database limits will still apply).
    // ok is false if the column type is not variable length, such as an int, or
    // if the driver does not support this property.
    func (ci *ColumnType) Length() (length int64, ok bool) {
    	length, ok = ci.length, ci.hasLength
    	return length, ok
    }
    
    // DecimalSize reports the precision and scale of a decimal type.
    // ok is false if they are not applicable or not supported.
    func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
    	precision, scale = ci.precision, ci.scale
    	ok = ci.hasPrecisionScale
    	return precision, scale, ok
    }
    
    
    // ScanType reports a Go type suitable for scanning into using [Rows.Scan].
    // If a driver does not support this property, the result is the type of an
    // empty interface.
    func (ci *ColumnType) ScanType() reflect.Type {
    	goType := ci.scanType
    	return goType
    }
    
    
    // Nullable reports whether the column may be null.
    // ok is false if a driver does not support this property.
    func (ci *ColumnType) Nullable() (nullable, ok bool) {
    	nullable, ok = ci.nullable, ci.hasNullable
    	return nullable, ok
    }
    
    // DatabaseTypeName reports the database system name of the column type.
    // An empty string means the driver does not support type names.
    // Consult your driver documentation for a list of driver data types.
    // [ColumnType.Length] specifiers are not included.
    // Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
    // "INT", and "BIGINT".
    func (ci *ColumnType) DatabaseTypeName() string {
    	typeName := ci.databaseType
    	return typeName
    }
    
    
    // rowsColumnInfoSetupConnLocked builds one ColumnType for each column that
    // rowsi reports, filling in whatever the optional driver.RowsColumnType*
    // interfaces implemented by rowsi can provide.
    func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
    	names := rowsi.Columns()

    	// Which optional interfaces rowsi implements does not depend on the
    	// column, so probe for each of them once up front.
    	scanTyper, hasScanType := rowsi.(driver.RowsColumnTypeScanType)
    	typeNamer, hasTypeName := rowsi.(driver.RowsColumnTypeDatabaseTypeName)
    	lengther, hasLength := rowsi.(driver.RowsColumnTypeLength)
    	nullabler, hasNullable := rowsi.(driver.RowsColumnTypeNullable)
    	precScaler, hasPrecScale := rowsi.(driver.RowsColumnTypePrecisionScale)

    	list := make([]*ColumnType, len(names))
    	for i, name := range names {
    		ci := &ColumnType{name: name}
    		list[i] = ci

    		// Without driver support the scan type is the type of any.
    		ci.scanType = reflect.TypeFor[any]()
    		if hasScanType {
    			ci.scanType = scanTyper.ColumnTypeScanType(i)
    		}
    		if hasTypeName {
    			ci.databaseType = typeNamer.ColumnTypeDatabaseTypeName(i)
    		}
    		if hasLength {
    			ci.length, ci.hasLength = lengther.ColumnTypeLength(i)
    		}
    		if hasNullable {
    			ci.nullable, ci.hasNullable = nullabler.ColumnTypeNullable(i)
    		}
    		if hasPrecScale {
    			ci.precision, ci.scale, ci.hasPrecisionScale = precScaler.ColumnTypePrecisionScale(i)
    		}
    	}
    	return list
    }
    
    
    // Scan copies the columns in the current row into the values pointed
    
    // at by dest. The number of values in dest must be the same as the
    
    // number of columns in [Rows].
    
    // Scan converts columns read from the database into the following
    // common Go types and special types provided by the sql package:
    //
    
    Russ Cox's avatar
    Russ Cox committed
    //	*string
    //	*[]byte
    //	*int, *int8, *int16, *int32, *int64
    //	*uint, *uint8, *uint16, *uint32, *uint64
    //	*bool
    //	*float32, *float64
    //	*interface{}
    
    //	*RawBytes
    //	*Rows (cursor value)
    //	any type implementing Scanner (see Scanner docs)
    
    //
    // In the most simple case, if the type of the value from the source
    // column is an integer, bool or string type T and dest is of type *T,
    // Scan simply assigns the value through the pointer.
    //
    // Scan also converts between string and numeric types, as long as no
    // information would be lost. While Scan stringifies all numbers
    // scanned from numeric database columns into *string, scans into
    // numeric types are checked for overflow. For example, a float64 with
    // value 300 or a string with value "300" can scan into a uint16, but
    // not into a uint8, though float64(255) or "255" can scan into a
    // uint8. One exception is that scans of some float64 numbers to
    // strings may lose information when stringifying. In general, scan
    // floating point columns into *float64.
    //
    // If a dest argument has type *[]byte, Scan saves in that argument a
    // copy of the corresponding data. The copy is owned by the caller and
    // can be modified and held indefinitely. The copy can be avoided by
    
    // using an argument of type [*RawBytes] instead; see the documentation
    
    // for [RawBytes] for restrictions on its use.
    
    //
    // If an argument has type *interface{}, Scan copies the value
    
    // provided by the underlying driver without conversion. When scanning
    // from a source value of type []byte to *interface{}, a copy of the
    // slice is made and the caller owns the result.
    //
    
    // Source values of type [time.Time] may be scanned into values of type
    
    // *time.Time, *interface{}, *string, or *[]byte. When converting to
    
    // the latter two, [time.RFC3339Nano] is used.
    
    //
    // Source values of type bool may be scanned into types *bool,
    
    // *interface{}, *string, *[]byte, or [*RawBytes].
    
    //
    // For scanning into *bool, the source may be true, false, 1, 0, or
    
    // string inputs parseable by [strconv.ParseBool].
    
    //
    // Scan can also convert a cursor returned from a query, such as
    // "select cursor(select * from my_table) from dual", into a
    
    // [*Rows] value that can itself be scanned from. The parent
    // select query will close any cursor [*Rows] if the parent [*Rows] is closed.
    
    // If any of the first arguments implementing [Scanner] returns an error,
    
    // that error will be wrapped in the returned error.
    
    func (rs *Rows) Scan(dest ...any) error {
    
    	if rs.closemuScanHold {
    		// This should only be possible if the user calls Scan twice in a row
    		// without calling Next.
    		return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
    	}
    
    
    	if rs.lasterr != nil && rs.lasterr != io.EOF {
    		rs.closemu.RUnlock()
    		return rs.lasterr
    	}
    
    		err := rs.lasterrOrErrLocked(errRowsClosed)
    
    
    	if scanArgsContainRawBytes(dest) {
    		rs.closemuScanHold = true
    
    	if rs.lastcols == nil {
    
    		rs.closemuRUnlockIfHeldByScan()
    
    Brad Fitzpatrick's avatar
    Brad Fitzpatrick committed
    		return errors.New("sql: Scan called without calling Next")
    
    	}
    	if len(dest) != len(rs.lastcols) {
    
    		rs.closemuRUnlockIfHeldByScan()
    
    Brad Fitzpatrick's avatar
    Brad Fitzpatrick committed
    		return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
    
    	for i, sv := range rs.lastcols {
    
    		err := convertAssignRows(dest[i], sv, rs)
    
    			rs.closemuRUnlockIfHeldByScan()
    
    			return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
    
    // closemuRUnlockIfHeldByScan drops the closemu read lock left held by an
    // earlier call to Scan with *RawBytes. It does nothing if no such lock is
    // held.
    func (rs *Rows) closemuRUnlockIfHeldByScan() {
    	if !rs.closemuScanHold {
    		return
    	}
    	rs.closemuScanHold = false
    	rs.closemu.RUnlock()
    }
    
    // scanArgsContainRawBytes reports whether any element of args is a *RawBytes.
    func scanArgsContainRawBytes(args []any) bool {
    	for i := range args {
    		switch args[i].(type) {
    		case *RawBytes:
    			return true
    		}
    	}
    	return false
    }
    
    
    // rowsCloseHook returns a function so tests may install the
    // hook through a test only mutex. By default it returns no hook.
    var rowsCloseHook = func() func(*Rows, *error) {
    	return nil
    }
    
    // Close closes the [Rows], preventing further enumeration. If [Rows.Next] is
    // called and returns false and there are no further result sets, the [Rows]
    // are closed automatically and it will suffice to check the result of
    // [Rows.Err]. Close is idempotent and does not affect the result of
    // [Rows.Err].
    func (rs *Rows) Close() error {
    	// A call to Close means the caller no longer needs the previous row's
    	// Scan results (RawBytes memory included). Drop the read lock a prior
    	// Scan may still hold, since the unexported close takes closemu
    	// exclusively.
    	rs.closemuRUnlockIfHeldByScan()

    	// An explicit Close supplies no error of its own.
    	var noErr error
    	return rs.close(noErr)
    }
    
    func (rs *Rows) close(err error) error {
    	rs.closemu.Lock()
    	defer rs.closemu.Unlock()
    
    	if rs.closed {
    
    	rs.closed = true
    
    	if rs.lasterr == nil {
    		rs.lasterr = err
    	}
    
    	withLock(rs.dc, func() {
    		err = rs.rowsi.Close()
    	})
    
    	if fn := rowsCloseHook(); fn != nil {
    
    	if rs.closeStmt != nil {
    		rs.closeStmt.Close()
    	}
    
    
    	rs.lasterr = rs.lasterrOrErrLocked(err)
    
    // Row is the result of calling [DB.QueryRow] to select a single row.
    type Row struct {
    	// One of these two will be non-nil:

    	// err, when set, is what Scan and Err return.
    	err  error // deferred error for easy chaining

    	// rows is the result set that Scan reads the row from.
    	rows *Rows
    }
    
    // Scan copies the columns from the matched row into the values
    
    // pointed at by dest. See the documentation on [Rows.Scan] for details.
    
    // If more than one row matches the query,
    // Scan uses the first row and discards the rest. If no row matches
    
    // the query, Scan returns [ErrNoRows].
    
    func (r *Row) Scan(dest ...any) error {
    
    	if r.err != nil {
    		return r.err
    	}
    
    
    	// TODO(bradfitz): for now we need to defensively clone all
    
    	// []byte that the driver returned (not permitting
    
    	// *RawBytes in Rows.Scan), since we're about to close
    
    	// the Rows in our defer, when we return from this function.
    	// The contract with the driver.Next(...) interface is that it
    	// can return slices into read-only temporary memory that's
    
    	// only valid until the next Scan/Close. But the TODO is that
    	// for a lot of drivers, this copy will be unnecessary. We
    
    	// should provide an optional interface for drivers to
    	// implement to say, "don't worry, the []bytes that I return
    	// from Next will not be modified again." (for instance, if
    	// they were obtained from the network anyway) But for now we
    	// don't care.
    
    	if scanArgsContainRawBytes(dest) {
    		return errors.New("sql: RawBytes isn't allowed on Row.Scan")
    
    		if err := r.rows.Err(); err != nil {
    			return err
    		}
    
    		return ErrNoRows
    	}
    	err := r.rows.Scan(dest...)
    	if err != nil {
    		return err
    	}
    
    	// Make sure the query can be processed to completion with no errors.
    
    	return r.rows.Close()
    
    // Err lets wrapping packages check for query errors without calling
    // [Row.Scan]. It returns the error, if any, that was encountered while
    // running the query. A non-nil error is also returned from [Row.Scan].
    func (r *Row) Err() error {
    	queryErr := r.err
    	return queryErr
    }
    
    
    // A Result summarizes an executed SQL command.
    type Result interface {
    	// LastInsertId returns the integer generated by the database
    	// in response to a command. Typically this will be from an
    	// "auto increment" column when inserting a new row. Not all
    	// databases support this feature, and the syntax of such
    	// statements varies.
    	LastInsertId() (int64, error)

    	// RowsAffected returns the number of rows affected by an
    	// update, insert, or delete. Not every database or database
    	// driver may support this.
    	RowsAffected() (int64, error)
    }
    
    // driverResult pairs a driver.Result with a lock. Its LastInsertId method
    // holds the embedded Locker while calling into resi.
    type driverResult struct {
    	sync.Locker // the *driverConn
    	resi        driver.Result // the wrapped driver result
    }
    
    // LastInsertId forwards to the wrapped driver result while holding the
    // embedded lock.
    func (dr driverResult) LastInsertId() (int64, error) {
    	dr.Locker.Lock()
    	defer dr.Locker.Unlock()

    	id, err := dr.resi.LastInsertId()
    	return id, err
    }
    
    func (dr driverResult) RowsAffected() (int64, error) {
    	dr.Lock()
    	defer dr.Unlock()
    	return dr.resi.RowsAffected()
    
    	return string(buf[:runtime.Stack(buf[:], false)])
    }
    
    
    // withLock runs fn while holding lk.
    //
    // lk is locked before fn is called and released once fn returns. The
    // release is deferred, so it also happens when fn panics.
    func withLock(lk sync.Locker, fn func()) {
    	lk.Lock()
    	defer lk.Unlock() // in case fn panics
    	fn()
    }
    
    
    // connRequestSet is a set of chan connRequest that's
    // optimized for:
    //
    //   - adding an element
    //   - removing an element (only by the caller who added it)
    //   - taking (get + delete) a random element
    //
    // We previously used a map for this but the take of a random element
    // was expensive, making mapiters. This type avoids a map entirely
    // and just uses a slice.
    type connRequestSet struct {
    	// s are the elements in the set. Each element's curIdx points at the
    	// element's own position in s; deleteIndex keeps it up to date when it
    	// moves the last element into a vacated slot.
    	s []connRequestAndIndex
    }
    
    // connRequestAndIndex is one element of connRequestSet.s: a request channel
    // together with a pointer to the element's current index. Add hands the
    // same pointer out inside the delete handle it returns.
    type connRequestAndIndex struct {
    	// req is the element in the set.
    	req chan connRequest

    	// curIdx points to the current location of this element in
    	// connRequestSet.s. It gets set to -1 upon removal.
    	curIdx *int
    }
    
    // CloseAndRemoveAll closes all channels in the set
    // and clears the set.
    //
    // Every element is also marked as removed (its index is set to -1), so a
    // delete handle that outlives this call makes Delete report false instead
    // of indexing into the cleared slice.
    func (s *connRequestSet) CloseAndRemoveAll() {
    	for _, v := range s.s {
    		// Mark the element as removed before the slice is dropped; the
    		// handle returned by Add shares this index pointer.
    		*v.curIdx = -1
    		close(v.req)
    	}
    	s.s = nil
    }
    
    // Len reports how many elements the set currently holds.
    func (s *connRequestSet) Len() int {
    	return len(s.s)
    }
    
    // connRequestDelHandle is an opaque handle, obtained from calling Add,
    // that can be passed to Delete to remove the added item from the set.
    type connRequestDelHandle struct {
    	// idx is shared with the element's curIdx, so it follows the element
    	// when deleteIndex moves it to another position.
    	idx *int // pointer to index; or -1 if not in slice
    }
    
    // Add appends v to the set of waiting requests.
    // The connRequestDelHandle it returns can later be passed to Delete to
    // remove the item from the set.
    func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
    	// TODO(bradfitz): for simplicity, this always allocates a new int-sized
    	// allocation to store the index. But generally the set will be small and
    	// under a scannable-threshold. As an optimization, we could permit the *int
    	// to be nil when the set is small and should be scanned. This works even if
    	// the set grows over the threshold with delete handles outstanding because
    	// an element can only move to a lower index. So if it starts with a nil
    	// position, it'll always be in a low index and thus scannable. But that
    	// can be done in a follow-up change.
    	pos := new(int)
    	*pos = len(s.s) // the new element goes at the end

    	// The element and the handle share pos, so the handle sees later moves.
    	s.s = append(s.s, connRequestAndIndex{req: v, curIdx: pos})
    	return connRequestDelHandle{idx: pos}
    }
    
    // Delete removes the element that h refers to from the set.
    //
    // It reports whether the element was deleted. (The result is false if a
    // caller of TakeRandom took it meanwhile, or upon the second call to Delete)
    func (s *connRequestSet) Delete(h connRequestDelHandle) bool {
    	pos := *h.idx
    	present := pos >= 0 // a negative index marks an already removed element
    	if present {
    		s.deleteIndex(pos)
    	}
    	return present
    }
    
    // deleteIndex removes the element at position idx. The last element takes
    // over the vacated slot, so the order of elements is not preserved.
    func (s *connRequestSet) deleteIndex(idx int) {
    	tail := len(s.s) - 1

    	// Flag the element as no longer being in the set.
    	*s.s[idx].curIdx = -1

    	// Unless the removed element was the last one, move the last element
    	// into the hole and record its new position.
    	if idx < tail {
    		moved := s.s[tail]
    		*moved.curIdx = idx
    		s.s[idx] = moved
    	}

    	// Clear the vacated tail slot (for GC), then shrink the slice.
    	s.s[tail] = connRequestAndIndex{}
    	s.s = s.s[:tail]
    }
    
    // TakeRandom removes a randomly chosen element from s and returns it,
    // reporting whether there was one to take. (ok is false if the set is
    // empty.)
    func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) {
    	n := len(s.s)
    	if n == 0 {
    		return nil, false
    	}
    	pick := rand.IntN(n)
    	// Read the channel out before deleteIndex overwrites the slot.
    	v = s.s[pick].req
    	s.deleteIndex(pick)
    	return v, true
    }