[ARVADOS] updated: 1.2.0-188-gd59ed50f6
Git user
git at public.curoverse.com
Thu Oct 18 17:31:49 EDT 2018
Summary of changes:
lib/controller/fed_collections.go | 312 +++++++++++++++++++
lib/controller/fed_generic.go | 331 ++++++++++++++++++++
lib/controller/federation.go | 640 +-------------------------------------
lib/controller/handler.go | 16 +-
lib/controller/proxy.go | 65 ++--
5 files changed, 683 insertions(+), 681 deletions(-)
create mode 100644 lib/controller/fed_collections.go
create mode 100644 lib/controller/fed_generic.go
via d59ed50f6c87d8ce9545df786d506337b74ebafd (commit)
via 0c99017642e413140b5315ddae8a99a7fb44a293 (commit)
via 748c5e85538b145e51777fe1015b943546d9ca06 (commit)
from bdd226bc9ab3f9f957963171b8d9c3d4355f472c (commit)
Those revisions listed above that are new to this repository have
not appeared on any other notification email; so we list those
revisions in full, below.
commit d59ed50f6c87d8ce9545df786d506337b74ebafd
Author: Peter Amstutz <pamstutz at veritasgenetics.com>
Date: Thu Oct 18 17:17:33 2018 -0400
14262: Refactoring, split up federation code into smaller files
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz at veritasgenetics.com>
diff --git a/lib/controller/fed_collections.go b/lib/controller/fed_collections.go
new file mode 100644
index 000000000..62f98367c
--- /dev/null
+++ b/lib/controller/fed_collections.go
@@ -0,0 +1,312 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/md5"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "sync"
+
+ "git.curoverse.com/arvados.git/sdk/go/arvados"
+ "git.curoverse.com/arvados.git/sdk/go/httpserver"
+ "git.curoverse.com/arvados.git/sdk/go/keepclient"
+)
+
+// collectionFederatedRequestHandler intercepts collection GET requests.
+// A GET by UUID for a collection owned by another cluster is proxied to
+// that cluster (with block signatures rewritten), a GET by portable data
+// hash is searched for across the federation, and everything else is
+// handed to next.
+type collectionFederatedRequestHandler struct {
+	next    http.Handler // receives requests this handler does not intercept
+	handler *Handler     // owning controller: cluster config, proxy and request helpers
+}
+
+// rewriteSignatures rewrites the manifest of a collection record fetched
+// from cluster clusterID. Each token matching keepclient.SignedLocatorRe
+// is re-emitted in a "+R<clusterID>-..." remote-signature form built from
+// the regexp's submatches (see the Fprintf below). The manifest is also
+// verified against a known hash.
+//
+// If requestError is non-nil, resp and requestError are returned as-is.
+// A non-200 resp is also returned unchanged.
+//
+// Otherwise the body is decoded as an arvados.Collection, and the
+// original body is closed on return. The md5 of the manifest with
+// signatures ignored, formatted as "<md5hex>+<size>", must equal
+// expectHash. When expectHash is "", the record's own portable_data_hash
+// is used instead. When expectHash is given, it must also equal the
+// record's portable_data_hash. Any mismatch or malformed stream line
+// yields (nil, error).
+//
+// On success resp is returned with its body replaced by the re-encoded
+// record and its Content-Length updated.
+func rewriteSignatures(clusterID string, expectHash string,
+	resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+
+	if requestError != nil {
+		return resp, requestError
+	}
+
+	if resp.StatusCode != 200 {
+		return resp, nil
+	}
+
+	originalBody := resp.Body
+	defer originalBody.Close()
+
+	var col arvados.Collection
+	err = json.NewDecoder(resp.Body).Decode(&col)
+	if err != nil {
+		return nil, err
+	}
+
+	// rewriting signatures will make manifest text 5-10% bigger so calculate
+	// capacity accordingly
+	updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
+
+	// hasher receives the manifest with signatures ignored (used for the
+	// hash check), updatedManifest receives the rewritten manifest, and
+	// sz counts the bytes fed to hasher.
+	hasher := md5.New()
+	mw := io.MultiWriter(hasher, updatedManifest)
+	sz := 0
+
+	scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
+	// Maximum line size is the larger of the 1 MiB initial buffer and
+	// the length of the whole manifest.
+	scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
+	for scanner.Scan() {
+		line := scanner.Text()
+		tokens := strings.Split(line, " ")
+		if len(tokens) < 3 {
+			return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
+		}
+
+		n, err := mw.Write([]byte(tokens[0]))
+		if err != nil {
+			return nil, fmt.Errorf("Error updating manifest: %v", err)
+		}
+		sz += n
+		for _, token := range tokens[1:] {
+			n, err = mw.Write([]byte(" "))
+			if err != nil {
+				return nil, fmt.Errorf("Error updating manifest: %v", err)
+			}
+			sz += n
+
+			m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
+			if m != nil {
+				// Rewrite the block signature to be a remote signature
+				_, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+
+				// for hash checking, ignore signatures
+				n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+				sz += n
+			} else {
+				n, err = mw.Write([]byte(token))
+				if err != nil {
+					return nil, fmt.Errorf("Error updating manifest: %v", err)
+				}
+				sz += n
+			}
+		}
+		n, err = mw.Write([]byte("\n"))
+		if err != nil {
+			return nil, fmt.Errorf("Error updating manifest: %v", err)
+		}
+		sz += n
+	}
+	// Scan stops silently on a read error or an over-long line. Report
+	// that directly instead of a misleading hash mismatch below.
+	if err := scanner.Err(); err != nil {
+		return nil, fmt.Errorf("Error reading manifest: %v", err)
+	}
+
+	// Check that expected hash is consistent with
+	// portable_data_hash field of the returned record
+	if expectHash == "" {
+		expectHash = col.PortableDataHash
+	} else if expectHash != col.PortableDataHash {
+		return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q", col.PortableDataHash, expectHash)
+	}
+
+	// Certify that the computed hash of the manifest_text matches our expectation
+	sum := hasher.Sum(nil)
+	computedHash := fmt.Sprintf("%x+%v", sum, sz)
+	if computedHash != expectHash {
+		return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
+	}
+
+	col.ManifestText = updatedManifest.String()
+
+	newbody, err := json.Marshal(col)
+	if err != nil {
+		return nil, err
+	}
+
+	buf := bytes.NewBuffer(newbody)
+	resp.Body = ioutil.NopCloser(buf)
+	resp.ContentLength = int64(buf.Len())
+	resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
+
+	return resp, nil
+}
+
+// filterLocalClusterResponse decides whether the local cluster's reply to
+// a collection-by-PDH request should go straight back to the client.
+//
+// A request error, or any status other than 404, is passed through
+// unchanged. A 404 is suppressed by returning (nil, nil), so the caller
+// goes on to search the remote clusters. Nothing else will read the
+// suppressed reply, so its body is closed here, which lets net/http
+// release or reuse the connection.
+func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+	if requestError != nil {
+		return resp, requestError
+	}
+
+	if resp.StatusCode == http.StatusNotFound {
+		// Suppress returning this result, because we want to
+		// search the federation.
+		if resp.Body != nil {
+			resp.Body.Close()
+		}
+		return nil, nil
+	}
+	return resp, nil
+}
+
+// searchRemoteClusterForPDH holds the state one goroutine needs while
+// asking a single remote cluster for a collection by portable data hash.
+// The pointer fields refer to state shared by every search started for
+// the same client request. The bool, slice and int they point to must
+// only be accessed while holding mtx.
+type searchRemoteClusterForPDH struct {
+	pdh           string           // portable data hash being looked up
+	remoteID      string           // cluster this search queries
+	mtx           *sync.Mutex      // guards *sentResponse, *errors and *statusCode
+	sentResponse  *bool            // true once some search has produced the reply
+	sharedContext *context.Context // shared request context (not read by filterRemoteClusterResponse)
+	cancelFunc    func()           // cancels the shared context once a reply is chosen
+	errors        *[]string        // accumulated per-cluster failure messages
+	statusCode    *int             // status to report if no cluster succeeds
+}
+
+// filterRemoteClusterResponse handles one remote cluster's reply to a
+// collection-by-PDH search.
+//
+// Only the first reply that meets both conditions below is returned:
+//   - its status is 200
+//   - its manifest passes rewriteSignatures against s.pdh
+//
+// Before returning it, *s.sentResponse is set and s.cancelFunc is called.
+//
+// Every other reply yields (nil, nil):
+//   - request errors, non-200 statuses and rewrite failures are appended
+//     to *s.errors
+//   - a non-404 status also sets *s.statusCode to 502 Bad Gateway
+//   - suppressed replies are closed here, because no one else will read
+//     them
+//
+// The error result is always nil. Shared state is only touched while
+// holding s.mtx. The lock is dropped while the body is read and
+// rewritten, so a slow cluster does not block a faster one.
+func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
+	// discard releases a reply that will not be forwarded.
+	discard := func() {
+		if resp != nil && resp.Body != nil {
+			resp.Body.Close()
+		}
+	}
+
+	s.mtx.Lock()
+	defer s.mtx.Unlock()
+
+	if *s.sentResponse {
+		// Another request already returned a response
+		discard()
+		return nil, nil
+	}
+
+	if requestError != nil {
+		*s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
+		// Record the error and suppress response
+		discard()
+		return nil, nil
+	}
+
+	if resp.StatusCode != 200 {
+		// Suppress returning unsuccessful result. Maybe
+		// another request will find it.
+		// TODO collect and return error responses.
+		*s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
+		if resp.StatusCode != 404 {
+			// Got a non-404 error response, convert into BadGateway
+			*s.statusCode = http.StatusBadGateway
+		}
+		discard()
+		return nil, nil
+	}
+
+	s.mtx.Unlock()
+
+	// This reads the response body. We don't want to hold the
+	// lock while doing this because other remote requests could
+	// also have made it to this point, and we don't want a
+	// slow response holding the lock to block a faster response
+	// that is waiting on the lock. rewriteSignatures closes the
+	// original body itself.
+	newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
+
+	s.mtx.Lock()
+
+	if *s.sentResponse {
+		// Another request already returned a response
+		return nil, nil
+	}
+
+	if err != nil {
+		// Suppress returning unsuccessful result. Maybe
+		// another request will be successful.
+		*s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
+		return nil, nil
+	}
+
+	// We have a successful response. Suppress/cancel all the
+	// other requests/responses.
+	*s.sentResponse = true
+	s.cancelFunc()
+
+	return newResponse, nil
+}
+
+// ServeHTTP intercepts GET requests for collections:
+//
+//   - A GET by UUID for a collection owned by another cluster is proxied
+//     to that cluster, and the reply's block signatures are rewritten.
+//   - A GET by portable data hash is tried on the local cluster first.
+//     If that returns 404, every other cluster in RemoteClusters is
+//     queried concurrently (at most MultiClusterRequestConcurrency at a
+//     time), and the first verified 200 reply is forwarded. If none
+//     succeeds, the collected errors are returned with 404, or with 502
+//     if any cluster gave a non-404 error.
+//
+// Everything else is passed to h.next.
+func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	if req.Method != "GET" {
+		// Only handle GET requests right now
+		h.next.ServeHTTP(w, req)
+		return
+	}
+
+	m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
+	if len(m) != 2 {
+		// Not a collection PDH GET request
+		m = collectionRe.FindStringSubmatch(req.URL.Path)
+		clusterId := ""
+
+		if len(m) > 0 {
+			clusterId = m[2]
+		}
+
+		if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
+			// request for remote collection by uuid
+			resp, err := h.handler.remoteClusterRequest(clusterId, req)
+			newResponse, err := rewriteSignatures(clusterId, "", resp, err)
+			h.handler.proxy.ForwardResponse(w, newResponse, err)
+			return
+		}
+		// not a collection UUID request, or it is a request
+		// for a local UUID, either way, continue down the
+		// handler stack.
+		h.next.ServeHTTP(w, req)
+		return
+	}
+
+	// Request for collection by PDH. Search the federation.
+
+	// First, query the local cluster.
+	resp, err := h.handler.localClusterRequest(req)
+	newResp, err := filterLocalClusterResponse(resp, err)
+	if newResp != nil || err != nil {
+		h.handler.proxy.ForwardResponse(w, newResp, err)
+		return
+	}
+
+	sharedContext, cancelFunc := context.WithCancel(req.Context())
+	defer cancelFunc()
+	req = req.WithContext(sharedContext)
+
+	// Create a goroutine for each cluster in the
+	// RemoteClusters map. The first valid result gets
+	// returned to the client. When that happens, all
+	// other outstanding requests are cancelled or
+	// suppressed.
+	sentResponse := false
+	mtx := sync.Mutex{}
+	wg := sync.WaitGroup{}
+	var errors []string
+	errorCode := http.StatusNotFound
+
+	// use channel as a semaphore to limit the number of concurrent
+	// requests at a time
+	sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+	defer close(sem)
+	for remoteID := range h.handler.Cluster.RemoteClusters {
+		if remoteID == h.handler.Cluster.ClusterID {
+			// No need to query local cluster again
+			continue
+		}
+		// blocks until it can put a value into the
+		// channel (which has a max queue capacity)
+		sem <- true
+		// Worker goroutines set sentResponse while holding mtx,
+		// so it must be read under mtx here too; an unlocked
+		// read would be a data race.
+		mtx.Lock()
+		found := sentResponse
+		mtx.Unlock()
+		if found {
+			break
+		}
+		search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
+			&sharedContext, cancelFunc, &errors, &errorCode}
+		wg.Add(1)
+		go func() {
+			resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
+			newResp, err := search.filterRemoteClusterResponse(resp, err)
+			if newResp != nil || err != nil {
+				h.handler.proxy.ForwardResponse(w, newResp, err)
+			}
+			wg.Done()
+			<-sem
+		}()
+	}
+	// wg.Wait orders every goroutine's writes to sentResponse, errors
+	// and errorCode before the unlocked reads below.
+	wg.Wait()
+
+	if sentResponse {
+		return
+	}
+
+	// No successful responses, so return the error
+	httpserver.Errors(w, errors, errorCode)
+}
diff --git a/lib/controller/fed_generic.go b/lib/controller/fed_generic.go
new file mode 100644
index 000000000..0630217b6
--- /dev/null
+++ b/lib/controller/fed_generic.go
@@ -0,0 +1,331 @@
+// Copyright (C) The Arvados Authors. All rights reserved.
+//
+// SPDX-License-Identifier: AGPL-3.0
+
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "regexp"
+ "sync"
+
+ "git.curoverse.com/arvados.git/sdk/go/httpserver"
+)
+
+// genericFederatedRequestHandler routes API requests whose path matches
+// matcher. Depending on the cluster named by the request, it serves them
+// via next (local), proxies them to a remote cluster, or fans a
+// multi-cluster UUID query out to several clusters.
+type genericFederatedRequestHandler struct {
+	next    http.Handler   // local handler for requests not sent elsewhere
+	handler *Handler       // owning controller: cluster config, proxy and request helpers
+	matcher *regexp.Regexp // path pattern; submatch 2 is read as the target cluster ID
+}
+
+// remoteQueryUUIDs fetches the objects named in uuids from cluster
+// clusterID. The local cluster is reached via localClusterRequest, any
+// other cluster via remoteClusterRequest. Each round sends a POST with
+// _method=GET, count=none, a [["uuid","in",uuids]] filter, and the
+// caller's select parameter if one was given.
+//
+// If a reply omits some of the requested UUIDs, the stragglers are
+// requested again. The loop stops when:
+//   - every UUID has been found,
+//   - a reply contains no items, or
+//   - a round makes no progress.
+//
+// It returns the accumulated items and the "kind" of the last reply.
+// Any request, decode or non-200 error aborts the call with that error.
+// w is not used.
+func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
+	req *http.Request,
+	clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
+
+	found := make(map[string]bool)
+	prevLenUUIDs := len(uuids) + 1
+	// Loop while
+	// (1) there are more uuids to query
+	// (2) we're making progress - on each iteration the set of
+	// uuids we are expecting for must shrink.
+	for len(uuids) > 0 && len(uuids) < prevLenUUIDs {
+		remoteParams := make(url.Values)
+		remoteParams.Set("_method", "GET")
+		remoteParams.Set("count", "none")
+		if req.Form.Get("select") != "" {
+			remoteParams.Set("select", req.Form.Get("select"))
+		}
+		content, err := json.Marshal(uuids)
+		if err != nil {
+			return nil, "", err
+		}
+		remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
+		enc := remoteParams.Encode()
+
+		// Derive the sub-request's context from the inbound
+		// request. A zero-value http.Request carries
+		// context.Background(), which would hide the client's
+		// cancellation from helpers that honor Request.Context().
+		remoteReq := (&http.Request{
+			Method: "POST",
+			URL:    &url.URL{Path: req.URL.Path},
+			Header: req.Header,
+			Body:   ioutil.NopCloser(bytes.NewBufferString(enc)),
+		}).WithContext(req.Context())
+
+		rc := multiClusterQueryResponseCollector{clusterID: clusterID}
+
+		var resp *http.Response
+		if clusterID == h.handler.Cluster.ClusterID {
+			resp, err = h.handler.localClusterRequest(remoteReq)
+		} else {
+			resp, err = h.handler.remoteClusterRequest(clusterID, remoteReq)
+		}
+		rc.collectResponse(resp, err)
+
+		if rc.error != nil {
+			return nil, "", rc.error
+		}
+
+		kind = rc.kind
+
+		if len(rc.responses) == 0 {
+			// We got zero responses, no point in doing
+			// another query.
+			return rp, kind, nil
+		}
+
+		rp = append(rp, rc.responses...)
+
+		// Go through the responses and determine what was
+		// returned. If there are remaining items, loop
+		// around and do another request with just the
+		// stragglers.
+		for _, i := range rc.responses {
+			uuid, ok := i["uuid"].(string)
+			if ok {
+				found[uuid] = true
+			}
+		}
+
+		l := []string{}
+		for _, u := range uuids {
+			if !found[u] {
+				l = append(l, u)
+			}
+		}
+		prevLenUUIDs = len(uuids)
+		uuids = l
+	}
+
+	return rp, kind, nil
+}
+
+// handleMultiClusterQuery serves a list request whose filters select
+// objects by UUID from more than one cluster. It groups the UUIDs by
+// their 5-character cluster prefix and queries each cluster
+// concurrently, at most MultiClusterRequestConcurrency at a time.
+//
+// It returns true when it has written a response to w. That is either
+// the merged {"items","kind"} result, a 400 for malformed or unsupported
+// parameters, or a 502 if any cluster query failed.
+//
+// It returns false, having written nothing, when the filters are not a
+// multi-cluster UUID query. Such filters include:
+//   - a non-uuid filter
+//   - an operator other than "in" or "="
+//   - a UUID too short to carry a cluster prefix
+//
+// In that case *clusterId may already hold the prefix of the last UUID
+// examined, and callers may route on it.
+func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
+	req *http.Request, clusterId *string) bool {
+
+	var filters [][]interface{}
+	err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
+	if err != nil {
+		httpserver.Error(w, err.Error(), http.StatusBadRequest)
+		return true
+	}
+
+	// Split the list of uuids by prefix
+	queryClusters := make(map[string][]string)
+	expectCount := 0
+	// addUUID files u under its 5-character cluster prefix. Filter
+	// values come straight from the client, so a string too short to
+	// carry a prefix is rejected instead of being sliced, which would
+	// panic.
+	addUUID := func(u string) bool {
+		if len(u) < 5 {
+			return false
+		}
+		*clusterId = u[0:5]
+		queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
+		expectCount++
+		return true
+	}
+	for _, filter := range filters {
+		if len(filter) != 3 {
+			return false
+		}
+
+		if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
+			return false
+		}
+
+		op, ok := filter[1].(string)
+		if !ok {
+			return false
+		}
+
+		switch op {
+		case "in":
+			if rhs, ok := filter[2].([]interface{}); ok {
+				for _, i := range rhs {
+					if u, ok := i.(string); ok && !addUUID(u) {
+						return false
+					}
+				}
+			}
+		case "=":
+			if u, ok := filter[2].(string); ok && !addUUID(u) {
+				return false
+			}
+		default:
+			return false
+		}
+	}
+
+	if len(queryClusters) <= 1 {
+		// Query does not search for uuids across multiple
+		// clusters.
+		return false
+	}
+
+	// Validations
+	count := req.Form.Get("count")
+	if count != "" && count != `none` && count != `"none"` {
+		httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
+		return true
+	}
+	if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
+		httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
+		return true
+	}
+	if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
+		httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
+			expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
+		return true
+	}
+	if req.Form.Get("select") != "" {
+		foundUUID := false
+		var selects []string
+		err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
+		if err != nil {
+			httpserver.Error(w, err.Error(), http.StatusBadRequest)
+			return true
+		}
+
+		for _, r := range selects {
+			if r == "uuid" {
+				foundUUID = true
+				break
+			}
+		}
+		if !foundUUID {
+			httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
+			return true
+		}
+	}
+
+	// Perform concurrent requests to each cluster
+
+	// use channel as a semaphore to limit the number of concurrent
+	// requests at a time
+	sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
+	defer close(sem)
+	wg := sync.WaitGroup{}
+
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	mtx := sync.Mutex{}
+	errors := []error{}
+	var completeResponses []map[string]interface{}
+	var kind string
+
+	for k, v := range queryClusters {
+		if len(v) == 0 {
+			// Nothing to query
+			continue
+		}
+
+		// blocks until it can put a value into the
+		// channel (which has a max queue capacity)
+		sem <- true
+		wg.Add(1)
+		go func(k string, v []string) {
+			rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
+			mtx.Lock()
+			if err == nil {
+				completeResponses = append(completeResponses, rp...)
+				kind = kn
+			} else {
+				errors = append(errors, err)
+			}
+			mtx.Unlock()
+			wg.Done()
+			<-sem
+		}(k, v)
+	}
+	wg.Wait()
+
+	if len(errors) > 0 {
+		var strerr []string
+		for _, e := range errors {
+			strerr = append(strerr, e.Error())
+		}
+		httpserver.Errors(w, strerr, http.StatusBadGateway)
+		return true
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+	itemList := make(map[string]interface{})
+	itemList["items"] = completeResponses
+	itemList["kind"] = kind
+	json.NewEncoder(w).Encode(itemList)
+
+	return true
+}
+
+// ServeHTTP picks where a request matching h.matcher is served.
+//
+// The target cluster is submatch 2 of h.matcher, overridden by an
+// explicit cluster_id parameter. A GET that names no cluster but carries
+// filters may be answered entirely by handleMultiClusterQuery; a POST
+// with _method=GET counts as a GET. Otherwise a request for no cluster,
+// or for the local cluster, goes to h.next, and the rest are proxied to
+// the named remote cluster.
+func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	clusterId := ""
+	if m := h.matcher.FindStringSubmatch(req.URL.Path); len(m) > 0 && m[2] != "" {
+		clusterId = m[2]
+	}
+
+	// Populate req.Form from the query string and, for POST, the body.
+	if err := loadParamsFromForm(req); err != nil {
+		httpserver.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	// An explicit cluster_id parameter wins over the path.
+	if explicit := req.Form.Get("cluster_id"); explicit != "" {
+		clusterId = explicit
+	}
+
+	// POST with _method=GET is a GET whose parameters were moved into
+	// the body, e.g. filters listing hundreds of UUIDs that would not
+	// fit in a URL.
+	effectiveMethod := req.Method
+	if override := req.Form.Get("_method"); req.Method == "POST" && override != "" {
+		effectiveMethod = override
+	}
+
+	multiQueryCandidate := effectiveMethod == "GET" && clusterId == "" && req.Form.Get("filters") != ""
+	if multiQueryCandidate && h.handleMultiClusterQuery(w, req, &clusterId) {
+		return
+	}
+
+	switch clusterId {
+	case "", h.handler.Cluster.ClusterID:
+		h.next.ServeHTTP(w, req)
+	default:
+		resp, err := h.handler.remoteClusterRequest(clusterId, req)
+		h.handler.proxy.ForwardResponse(w, resp, err)
+	}
+}
+
+// multiClusterQueryResponseCollector records the outcome of a single
+// list request sent to one cluster.
+type multiClusterQueryResponseCollector struct {
+	responses []map[string]interface{} // "items" of a successful reply
+	error     error                    // request, decode or non-200 failure
+	kind      string                   // "kind" of a successful reply
+	clusterID string                   // cluster queried; used in error messages
+}
+
+// collectResponse stores the outcome of a list request in c.
+//
+// A request error is stored as-is. Otherwise the body is decoded and
+// closed. A decode failure or a non-200 status becomes c.error; on
+// success, the reply's items and kind are stored. It always returns
+// (nil, nil).
+func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
+	requestError error) (newResponse *http.Response, err error) {
+	if requestError != nil {
+		c.error = requestError
+		return nil, nil
+	}
+	defer resp.Body.Close()
+
+	var payload struct {
+		Kind   string                   `json:"kind"`
+		Items  []map[string]interface{} `json:"items"`
+		Errors []string                 `json:"errors"`
+	}
+	if decodeErr := json.NewDecoder(resp.Body).Decode(&payload); decodeErr != nil {
+		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, decodeErr)
+		return nil, nil
+	}
+
+	switch {
+	case resp.StatusCode != http.StatusOK:
+		c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, payload.Errors)
+	default:
+		c.responses = payload.Items
+		c.kind = payload.Kind
+	}
+	return nil, nil
+}
diff --git a/lib/controller/federation.go b/lib/controller/federation.go
index c5089fa23..03d2f3fab 100644
--- a/lib/controller/federation.go
+++ b/lib/controller/federation.go
@@ -5,12 +5,8 @@
package controller
import (
- "bufio"
"bytes"
- "context"
- "crypto/md5"
"database/sql"
- "encoding/json"
"fmt"
"io"
"io/ioutil"
@@ -18,12 +14,9 @@ import (
"net/url"
"regexp"
"strings"
- "sync"
"git.curoverse.com/arvados.git/sdk/go/arvados"
"git.curoverse.com/arvados.git/sdk/go/auth"
- "git.curoverse.com/arvados.git/sdk/go/httpserver"
- "git.curoverse.com/arvados.git/sdk/go/keepclient"
)
var pathPattern = `^/arvados/v1/%s(/([0-9a-z]{5})-%s-[0-9a-z]{15})?(.*)$`
@@ -33,17 +26,6 @@ var containerRequestsRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "container
var collectionRe = regexp.MustCompile(fmt.Sprintf(pathPattern, "collections", "4zz18"))
var collectionByPDHRe = regexp.MustCompile(`^/arvados/v1/collections/([0-9a-fA-F]{32}\+[0-9]+)+$`)
-type genericFederatedRequestHandler struct {
- next http.Handler
- handler *Handler
- matcher *regexp.Regexp
-}
-
-type collectionFederatedRequestHandler struct {
- next http.Handler
- handler *Handler
-}
-
func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
remote, ok := h.Cluster.RemoteClusters[remoteID]
if !ok {
@@ -98,597 +80,6 @@ func loadParamsFromForm(req *http.Request) error {
return nil
}
-type multiClusterQueryResponseCollector struct {
- responses []map[string]interface{}
- error error
- kind string
- clusterID string
-}
-
-func (c *multiClusterQueryResponseCollector) collectResponse(resp *http.Response,
- requestError error) (newResponse *http.Response, err error) {
- if requestError != nil {
- c.error = requestError
- return nil, nil
- }
-
- defer resp.Body.Close()
- var loadInto struct {
- Kind string `json:"kind"`
- Items []map[string]interface{} `json:"items"`
- Errors []string `json:"errors"`
- }
- err = json.NewDecoder(resp.Body).Decode(&loadInto)
-
- if err != nil {
- c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, err)
- return nil, nil
- }
- if resp.StatusCode != http.StatusOK {
- c.error = fmt.Errorf("error fetching from %v (%v): %v", c.clusterID, resp.Status, loadInto.Errors)
- return nil, nil
- }
-
- c.responses = loadInto.Items
- c.kind = loadInto.Kind
-
- return nil, nil
-}
-
-func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
- req *http.Request,
- clusterID string, uuids []string) (rp []map[string]interface{}, kind string, err error) {
-
- found := make(map[string]bool)
- prev_len_uuids := len(uuids) + 1
- // Loop while
- // (1) there are more uuids to query
- // (2) we're making progress - on each iteration the set of
- // uuids we are expecting for must shrink.
- for len(uuids) > 0 && len(uuids) < prev_len_uuids {
- var remoteReq http.Request
- remoteReq.Header = req.Header
- remoteReq.Method = "POST"
- remoteReq.URL = &url.URL{Path: req.URL.Path}
- remoteParams := make(url.Values)
- remoteParams.Set("_method", "GET")
- remoteParams.Set("count", "none")
- if req.Form.Get("select") != "" {
- remoteParams.Set("select", req.Form.Get("select"))
- }
- content, err := json.Marshal(uuids)
- if err != nil {
- return nil, "", err
- }
- remoteParams["filters"] = []string{fmt.Sprintf(`[["uuid", "in", %s]]`, content)}
- enc := remoteParams.Encode()
- remoteReq.Body = ioutil.NopCloser(bytes.NewBufferString(enc))
-
- rc := multiClusterQueryResponseCollector{clusterID: clusterID}
-
- var resp *http.Response
- if clusterID == h.handler.Cluster.ClusterID {
- resp, err = h.handler.localClusterRequest(&remoteReq)
- } else {
- resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
- }
- rc.collectResponse(resp, err)
-
- if rc.error != nil {
- return nil, "", rc.error
- }
-
- kind = rc.kind
-
- if len(rc.responses) == 0 {
- // We got zero responses, no point in doing
- // another query.
- return rp, kind, nil
- }
-
- rp = append(rp, rc.responses...)
-
- // Go through the responses and determine what was
- // returned. If there are remaining items, loop
- // around and do another request with just the
- // stragglers.
- for _, i := range rc.responses {
- uuid, ok := i["uuid"].(string)
- if ok {
- found[uuid] = true
- }
- }
-
- l := []string{}
- for _, u := range uuids {
- if !found[u] {
- l = append(l, u)
- }
- }
- prev_len_uuids = len(uuids)
- uuids = l
- }
-
- return rp, kind, nil
-}
-
-func (h *genericFederatedRequestHandler) handleMultiClusterQuery(w http.ResponseWriter,
- req *http.Request, clusterId *string) bool {
-
- var filters [][]interface{}
- err := json.Unmarshal([]byte(req.Form.Get("filters")), &filters)
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return true
- }
-
- // Split the list of uuids by prefix
- queryClusters := make(map[string][]string)
- expectCount := 0
- for _, filter := range filters {
- if len(filter) != 3 {
- return false
- }
-
- if lhs, ok := filter[0].(string); !ok || lhs != "uuid" {
- return false
- }
-
- op, ok := filter[1].(string)
- if !ok {
- return false
- }
-
- if op == "in" {
- if rhs, ok := filter[2].([]interface{}); ok {
- for _, i := range rhs {
- if u, ok := i.(string); ok {
- *clusterId = u[0:5]
- queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
- expectCount += 1
- }
- }
- }
- } else if op == "=" {
- if u, ok := filter[2].(string); ok {
- *clusterId = u[0:5]
- queryClusters[u[0:5]] = append(queryClusters[u[0:5]], u)
- expectCount += 1
- }
- } else {
- return false
- }
-
- }
-
- if len(queryClusters) <= 1 {
- // Query does not search for uuids across multiple
- // clusters.
- return false
- }
-
- // Validations
- count := req.Form.Get("count")
- if count != "" && count != `none` && count != `"none"` {
- httpserver.Error(w, "Federated multi-object query must have 'count=none'", http.StatusBadRequest)
- return true
- }
- if req.Form.Get("limit") != "" || req.Form.Get("offset") != "" || req.Form.Get("order") != "" {
- httpserver.Error(w, "Federated multi-object may not provide 'limit', 'offset' or 'order'.", http.StatusBadRequest)
- return true
- }
- if expectCount > h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse() {
- httpserver.Error(w, fmt.Sprintf("Federated multi-object request for %v objects which is more than max page size %v.",
- expectCount, h.handler.Cluster.RequestLimits.GetMaxItemsPerResponse()), http.StatusBadRequest)
- return true
- }
- if req.Form.Get("select") != "" {
- foundUUID := false
- var selects []string
- err := json.Unmarshal([]byte(req.Form.Get("select")), &selects)
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return true
- }
-
- for _, r := range selects {
- if r == "uuid" {
- foundUUID = true
- break
- }
- }
- if !foundUUID {
- httpserver.Error(w, "Federated multi-object request must include 'uuid' in 'select'", http.StatusBadRequest)
- return true
- }
- }
-
- // Perform concurrent requests to each cluster
-
- // use channel as a semaphore to limit the number of concurrent
- // requests at a time
- sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
- wg := sync.WaitGroup{}
-
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- mtx := sync.Mutex{}
- errors := []error{}
- var completeResponses []map[string]interface{}
- var kind string
-
- for k, v := range queryClusters {
- if len(v) == 0 {
- // Nothing to query
- continue
- }
-
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- wg.Add(1)
- go func(k string, v []string) {
- rp, kn, err := h.remoteQueryUUIDs(w, req, k, v)
- mtx.Lock()
- if err == nil {
- completeResponses = append(completeResponses, rp...)
- kind = kn
- } else {
- errors = append(errors, err)
- }
- mtx.Unlock()
- wg.Done()
- <-sem
- }(k, v)
- }
- wg.Wait()
-
- if len(errors) > 0 {
- var strerr []string
- for _, e := range errors {
- strerr = append(strerr, e.Error())
- }
- httpserver.Errors(w, strerr, http.StatusBadGateway)
- return true
- }
-
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- itemList := make(map[string]interface{})
- itemList["items"] = completeResponses
- itemList["kind"] = kind
- json.NewEncoder(w).Encode(itemList)
-
- return true
-}
-
-func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- m := h.matcher.FindStringSubmatch(req.URL.Path)
- clusterId := ""
-
- if len(m) > 0 && m[2] != "" {
- clusterId = m[2]
- }
-
- // Get form parameters from URL and form body (if POST).
- if err := loadParamsFromForm(req); err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
-
- // Check if the parameters have an explicit cluster_id
- if req.Form.Get("cluster_id") != "" {
- clusterId = req.Form.Get("cluster_id")
- }
-
- // Handle the POST-as-GET special case (workaround for large
- // GET requests that potentially exceed maximum URL length,
- // like multi-object queries where the filter has 100s of
- // items)
- effectiveMethod := req.Method
- if req.Method == "POST" && req.Form.Get("_method") != "" {
- effectiveMethod = req.Form.Get("_method")
- }
-
- if effectiveMethod == "GET" &&
- clusterId == "" &&
- req.Form.Get("filters") != "" &&
- h.handleMultiClusterQuery(w, req, &clusterId) {
- return
- }
-
- if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
- h.next.ServeHTTP(w, req)
- } else {
- resp, err := h.handler.remoteClusterRequest(clusterId, req)
- h.handler.proxy.ForwardResponse(w, resp, err)
- }
-}
-
-func rewriteSignatures(clusterID string, expectHash string,
- resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-
- if requestError != nil {
- return resp, requestError
- }
-
- if resp.StatusCode != 200 {
- return resp, nil
- }
-
- originalBody := resp.Body
- defer originalBody.Close()
-
- var col arvados.Collection
- err = json.NewDecoder(resp.Body).Decode(&col)
- if err != nil {
- return nil, err
- }
-
- // rewriting signatures will make manifest text 5-10% bigger so calculate
- // capacity accordingly
- updatedManifest := bytes.NewBuffer(make([]byte, 0, int(float64(len(col.ManifestText))*1.1)))
-
- hasher := md5.New()
- mw := io.MultiWriter(hasher, updatedManifest)
- sz := 0
-
- scanner := bufio.NewScanner(strings.NewReader(col.ManifestText))
- scanner.Buffer(make([]byte, 1048576), len(col.ManifestText))
- for scanner.Scan() {
- line := scanner.Text()
- tokens := strings.Split(line, " ")
- if len(tokens) < 3 {
- return nil, fmt.Errorf("Invalid stream (<3 tokens): %q", line)
- }
-
- n, err := mw.Write([]byte(tokens[0]))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- for _, token := range tokens[1:] {
- n, err = mw.Write([]byte(" "))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
-
- m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
- if m != nil {
- // Rewrite the block signature to be a remote signature
- _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
-
- // for hash checking, ignore signatures
- n, err = fmt.Fprintf(hasher, "%s%s", m[1], m[2])
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- } else {
- n, err = mw.Write([]byte(token))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- }
- }
- n, err = mw.Write([]byte("\n"))
- if err != nil {
- return nil, fmt.Errorf("Error updating manifest: %v", err)
- }
- sz += n
- }
-
- // Check that expected hash is consistent with
- // portable_data_hash field of the returned record
- if expectHash == "" {
- expectHash = col.PortableDataHash
- } else if expectHash != col.PortableDataHash {
- return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
- }
-
- // Certify that the computed hash of the manifest_text matches our expectation
- sum := hasher.Sum(nil)
- computedHash := fmt.Sprintf("%x+%v", sum, sz)
- if computedHash != expectHash {
- return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
- }
-
- col.ManifestText = updatedManifest.String()
-
- newbody, err := json.Marshal(col)
- if err != nil {
- return nil, err
- }
-
- buf := bytes.NewBuffer(newbody)
- resp.Body = ioutil.NopCloser(buf)
- resp.ContentLength = int64(buf.Len())
- resp.Header.Set("Content-Length", fmt.Sprintf("%v", buf.Len()))
-
- return resp, nil
-}
-
-func filterLocalClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- if requestError != nil {
- return resp, requestError
- }
-
- if resp.StatusCode == 404 {
- // Suppress returning this result, because we want to
- // search the federation.
- return nil, nil
- }
- return resp, nil
-}
-
-type searchRemoteClusterForPDH struct {
- pdh string
- remoteID string
- mtx *sync.Mutex
- sentResponse *bool
- sharedContext *context.Context
- cancelFunc func()
- errors *[]string
- statusCode *int
-}
-
-func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
- s.mtx.Lock()
- defer s.mtx.Unlock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if requestError != nil {
- *s.errors = append(*s.errors, fmt.Sprintf("Request error contacting %q: %v", s.remoteID, requestError))
- // Record the error and suppress response
- return nil, nil
- }
-
- if resp.StatusCode != 200 {
- // Suppress returning unsuccessful result. Maybe
- // another request will find it.
- // TODO collect and return error responses.
- *s.errors = append(*s.errors, fmt.Sprintf("Response from %q: %v", s.remoteID, resp.Status))
- if resp.StatusCode != 404 {
- // Got a non-404 error response, convert into BadGateway
- *s.statusCode = http.StatusBadGateway
- }
- return nil, nil
- }
-
- s.mtx.Unlock()
-
- // This reads the response body. We don't want to hold the
- // lock while doing this because other remote requests could
- // also have made it to this point, and we don't want a
- // slow response holding the lock to block a faster response
- // that is waiting on the lock.
- newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
-
- s.mtx.Lock()
-
- if *s.sentResponse {
- // Another request already returned a response
- return nil, nil
- }
-
- if err != nil {
- // Suppress returning unsuccessful result. Maybe
- // another request will be successful.
- *s.errors = append(*s.errors, fmt.Sprintf("Error parsing response from %q: %v", s.remoteID, err))
- return nil, nil
- }
-
- // We have a successful response. Suppress/cancel all the
- // other requests/responses.
- *s.sentResponse = true
- s.cancelFunc()
-
- return newResponse, nil
-}
-
-func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- if req.Method != "GET" {
- // Only handle GET requests right now
- h.next.ServeHTTP(w, req)
- return
- }
-
- m := collectionByPDHRe.FindStringSubmatch(req.URL.Path)
- if len(m) != 2 {
- // Not a collection PDH GET request
- m = collectionRe.FindStringSubmatch(req.URL.Path)
- clusterId := ""
-
- if len(m) > 0 {
- clusterId = m[2]
- }
-
- if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
- // request for remote collection by uuid
- resp, err := h.handler.remoteClusterRequest(clusterId, req)
- newResponse, err := rewriteSignatures(clusterId, "", resp, err)
- h.handler.proxy.ForwardResponse(w, newResponse, err)
- return
- }
- // not a collection UUID request, or it is a request
- // for a local UUID, either way, continue down the
- // handler stack.
- h.next.ServeHTTP(w, req)
- return
- }
-
- // Request for collection by PDH. Search the federation.
-
- // First, query the local cluster.
- resp, err := h.handler.localClusterRequest(req)
- newResp, err := filterLocalClusterResponse(resp, err)
- if newResp != nil || err != nil {
- h.handler.proxy.ForwardResponse(w, newResp, err)
- return
- }
-
- sharedContext, cancelFunc := context.WithCancel(req.Context())
- defer cancelFunc()
- req = req.WithContext(sharedContext)
-
- // Create a goroutine for each cluster in the
- // RemoteClusters map. The first valid result gets
- // returned to the client. When that happens, all
- // other outstanding requests are cancelled or
- // suppressed.
- sentResponse := false
- mtx := sync.Mutex{}
- wg := sync.WaitGroup{}
- var errors []string
- var errorCode int = 404
-
- // use channel as a semaphore to limit the number of concurrent
- // requests at a time
- sem := make(chan bool, h.handler.Cluster.RequestLimits.GetMultiClusterRequestConcurrency())
- defer close(sem)
- for remoteID := range h.handler.Cluster.RemoteClusters {
- if remoteID == h.handler.Cluster.ClusterID {
- // No need to query local cluster again
- continue
- }
- // blocks until it can put a value into the
- // channel (which has a max queue capacity)
- sem <- true
- if sentResponse {
- break
- }
- search := &searchRemoteClusterForPDH{m[1], remoteID, &mtx, &sentResponse,
- &sharedContext, cancelFunc, &errors, &errorCode}
- wg.Add(1)
- go func() {
- resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
- newResp, err := search.filterRemoteClusterResponse(resp, err)
- if newResp != nil || err != nil {
- h.handler.proxy.ForwardResponse(w, newResp, err)
- }
- wg.Done()
- <-sem
- }()
- }
- wg.Wait()
-
- if sentResponse {
- return
- }
-
- // No successful responses, so return the error
- httpserver.Errors(w, errors, errorCode)
-}
-
func (h *Handler) setupProxyRemoteCluster(next http.Handler) http.Handler {
mux := http.NewServeMux()
mux.Handle("/arvados/v1/workflows", &genericFederatedRequestHandler{next, h, wfRe})
commit 0c99017642e413140b5315ddae8a99a7fb44a293
Author: Peter Amstutz <pamstutz at veritasgenetics.com>
Date: Thu Oct 18 16:08:28 2018 -0400
14262: Refactoring proxy
Split proxy.Do() into ForwardRequest() and ForwardResponse().
Inversion of control eliminates need for "filter" callback, since the
caller can now modify the response in between the calls to
ForwardRequest() and ForwardResponse().
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz at veritasgenetics.com>
diff --git a/lib/controller/federation.go b/lib/controller/federation.go
index f30365574..c5089fa23 100644
--- a/lib/controller/federation.go
+++ b/lib/controller/federation.go
@@ -44,17 +44,10 @@ type collectionFederatedRequestHandler struct {
handler *Handler
}
-func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, req *http.Request, filter ResponseFilter) {
+func (h *Handler) remoteClusterRequest(remoteID string, req *http.Request) (*http.Response, error) {
remote, ok := h.Cluster.RemoteClusters[remoteID]
if !ok {
- err := fmt.Errorf("no proxy available for cluster %v", remoteID)
- if filter != nil {
- _, err = filter(nil, err)
- }
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusNotFound)
- }
- return
+ return nil, HTTPError{fmt.Sprintf("no proxy available for cluster %v", remoteID), http.StatusNotFound}
}
scheme := remote.Scheme
if scheme == "" {
@@ -62,13 +55,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, r
}
saltedReq, err := h.saltAuthToken(req, remoteID)
if err != nil {
- if filter != nil {
- _, err = filter(nil, err)
- }
- if err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadRequest)
- }
- return
+ return nil, err
}
urlOut := &url.URL{
Scheme: scheme,
@@ -81,7 +68,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, r
if remote.Insecure {
client = h.insecureClient
}
- h.proxy.Do(w, saltedReq, urlOut, client, filter)
+ return h.proxy.ForwardRequest(saltedReq, urlOut, client)
}
// Buffer request body, parse form parameters in request, and then
@@ -179,13 +166,14 @@ func (h *genericFederatedRequestHandler) remoteQueryUUIDs(w http.ResponseWriter,
rc := multiClusterQueryResponseCollector{clusterID: clusterID}
+ var resp *http.Response
if clusterID == h.handler.Cluster.ClusterID {
- h.handler.localClusterRequest(w, &remoteReq,
- rc.collectResponse)
+ resp, err = h.handler.localClusterRequest(&remoteReq)
} else {
- h.handler.remoteClusterRequest(clusterID, w, &remoteReq,
- rc.collectResponse)
+ resp, err = h.handler.remoteClusterRequest(clusterID, &remoteReq)
}
+ rc.collectResponse(resp, err)
+
if rc.error != nil {
return nil, "", rc.error
}
@@ -412,16 +400,14 @@ func (h *genericFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req *h
if clusterId == "" || clusterId == h.handler.Cluster.ClusterID {
h.next.ServeHTTP(w, req)
} else {
- h.handler.remoteClusterRequest(clusterId, w, req, nil)
+ resp, err := h.handler.remoteClusterRequest(clusterId, req)
+ h.handler.proxy.ForwardResponse(w, resp, err)
}
}
-type rewriteSignaturesClusterId struct {
- clusterID string
- expectHash string
-}
+func rewriteSignatures(clusterID string, expectHash string,
+ resp *http.Response, requestError error) (newResponse *http.Response, err error) {
-func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requestError error) (newResponse *http.Response, err error) {
if requestError != nil {
return resp, requestError
}
@@ -471,7 +457,7 @@ func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requ
m := keepclient.SignedLocatorRe.FindStringSubmatch(token)
if m != nil {
// Rewrite the block signature to be a remote signature
- _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], rw.clusterID, m[5][2:], m[8])
+ _, err = fmt.Fprintf(updatedManifest, "%s%s%s+R%s-%s%s", m[1], m[2], m[3], clusterID, m[5][2:], m[8])
if err != nil {
return nil, fmt.Errorf("Error updating manifest: %v", err)
}
@@ -499,17 +485,17 @@ func (rw rewriteSignaturesClusterId) rewriteSignatures(resp *http.Response, requ
// Check that expected hash is consistent with
// portable_data_hash field of the returned record
- if rw.expectHash == "" {
- rw.expectHash = col.PortableDataHash
- } else if rw.expectHash != col.PortableDataHash {
- return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", rw.expectHash, col.PortableDataHash)
+ if expectHash == "" {
+ expectHash = col.PortableDataHash
+ } else if expectHash != col.PortableDataHash {
+ return nil, fmt.Errorf("portable_data_hash %q on returned record did not match expected hash %q ", expectHash, col.PortableDataHash)
}
// Certify that the computed hash of the manifest_text matches our expectation
sum := hasher.Sum(nil)
computedHash := fmt.Sprintf("%x+%v", sum, sz)
- if computedHash != rw.expectHash {
- return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, rw.expectHash)
+ if computedHash != expectHash {
+ return nil, fmt.Errorf("Computed manifest_text hash %q did not match expected hash %q", computedHash, expectHash)
}
col.ManifestText = updatedManifest.String()
@@ -585,7 +571,7 @@ func (s *searchRemoteClusterForPDH) filterRemoteClusterResponse(resp *http.Respo
// also have made it to this point, and we don't want a
// slow response holding the lock to block a faster response
// that is waiting on the lock.
- newResponse, err = rewriteSignaturesClusterId{s.remoteID, s.pdh}.rewriteSignatures(resp, nil)
+ newResponse, err = rewriteSignatures(s.remoteID, s.pdh, resp, nil)
s.mtx.Lock()
@@ -628,8 +614,9 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
if clusterId != "" && clusterId != h.handler.Cluster.ClusterID {
// request for remote collection by uuid
- h.handler.remoteClusterRequest(clusterId, w, req,
- rewriteSignaturesClusterId{clusterId, ""}.rewriteSignatures)
+ resp, err := h.handler.remoteClusterRequest(clusterId, req)
+ newResponse, err := rewriteSignatures(clusterId, "", resp, err)
+ h.handler.proxy.ForwardResponse(w, newResponse, err)
return
}
// not a collection UUID request, or it is a request
@@ -642,7 +629,10 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
// Request for collection by PDH. Search the federation.
// First, query the local cluster.
- if h.handler.localClusterRequest(w, req, filterLocalClusterResponse) {
+ resp, err := h.handler.localClusterRequest(req)
+ newResp, err := filterLocalClusterResponse(resp, err)
+ if newResp != nil || err != nil {
+ h.handler.proxy.ForwardResponse(w, newResp, err)
return
}
@@ -680,7 +670,11 @@ func (h *collectionFederatedRequestHandler) ServeHTTP(w http.ResponseWriter, req
&sharedContext, cancelFunc, &errors, &errorCode}
wg.Add(1)
go func() {
- h.handler.remoteClusterRequest(search.remoteID, w, req, search.filterRemoteClusterResponse)
+ resp, err := h.handler.remoteClusterRequest(search.remoteID, req)
+ newResp, err := search.filterRemoteClusterResponse(resp, err)
+ if newResp != nil || err != nil {
+ h.handler.proxy.ForwardResponse(w, newResp, err)
+ }
wg.Done()
<-sem
}()
diff --git a/lib/controller/handler.go b/lib/controller/handler.go
index 0c31815cb..5e9012949 100644
--- a/lib/controller/handler.go
+++ b/lib/controller/handler.go
@@ -121,14 +121,10 @@ func prepend(next http.Handler, middleware middlewareFunc) http.Handler {
})
}
-// localClusterRequest sets up a request so it can be proxied to the
-// local API server using proxy.Do(). Returns true if a response was
-// written, false if not.
-func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request, filter ResponseFilter) bool {
+func (h *Handler) localClusterRequest(req *http.Request) (*http.Response, error) {
urlOut, insecure, err := findRailsAPI(h.Cluster, h.NodeProfile)
if err != nil {
- httpserver.Error(w, err.Error(), http.StatusInternalServerError)
- return true
+ return nil, err
}
urlOut = &url.URL{
Scheme: urlOut.Scheme,
@@ -141,12 +137,14 @@ func (h *Handler) localClusterRequest(w http.ResponseWriter, req *http.Request,
if insecure {
client = h.insecureClient
}
- return h.proxy.Do(w, req, urlOut, client, filter)
+ return h.proxy.ForwardRequest(req, urlOut, client)
}
func (h *Handler) proxyRailsAPI(w http.ResponseWriter, req *http.Request, next http.Handler) {
- if !h.localClusterRequest(w, req, nil) && next != nil {
- next.ServeHTTP(w, req)
+ resp, err := h.localClusterRequest(req)
+ n, err := h.proxy.ForwardResponse(w, resp, err)
+ if err != nil {
+ httpserver.Logger(req).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
}
}
diff --git a/lib/controller/proxy.go b/lib/controller/proxy.go
index 951cb9d25..9aecdc1b2 100644
--- a/lib/controller/proxy.go
+++ b/lib/controller/proxy.go
@@ -19,6 +19,15 @@ type proxy struct {
RequestTimeout time.Duration
}
+type HTTPError struct {
+ Message string
+ Code int
+}
+
+func (h HTTPError) Error() string {
+ return h.Message
+}
+
// headers that shouldn't be forwarded when proxying. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
var dropHeaders = map[string]bool{
@@ -36,15 +45,11 @@ var dropHeaders = map[string]bool{
type ResponseFilter func(*http.Response, error) (*http.Response, error)
-// Do sends a request, passes the result to the filter (if provided)
-// and then if the result is not suppressed by the filter, sends the
-// request to the ResponseWriter. Returns true if a response was written,
-// false if not.
-func (p *proxy) Do(w http.ResponseWriter,
+// Forward a request to downstream service, and return response or error.
+func (p *proxy) ForwardRequest(
reqIn *http.Request,
urlOut *url.URL,
- client *http.Client,
- filter ResponseFilter) bool {
+ client *http.Client) (*http.Response, error) {
// Copy headers from incoming request, then add/replace proxy
// headers like Via and X-Forwarded-For.
@@ -79,50 +84,26 @@ func (p *proxy) Do(w http.ResponseWriter,
Body: reqIn.Body,
}).WithContext(ctx)
- resp, err := client.Do(reqOut)
- if filter == nil && err != nil {
- httpserver.Error(w, err.Error(), http.StatusBadGateway)
- return true
- }
-
- // make sure original response body gets closed
- var originalBody io.ReadCloser
- if resp != nil {
- originalBody = resp.Body
- if originalBody != nil {
- defer originalBody.Close()
- }
- }
-
- if filter != nil {
- resp, err = filter(resp, err)
+ return client.Do(reqOut)
+}
- if err != nil {
+// Copy a response (or error) to the upstream client
+func (p *proxy) ForwardResponse(w http.ResponseWriter, resp *http.Response, err error) (int64, error) {
+ if err != nil {
+ if he, ok := err.(HTTPError); ok {
+ httpserver.Error(w, he.Message, he.Code)
+ } else {
httpserver.Error(w, err.Error(), http.StatusBadGateway)
- return true
- }
- if resp == nil {
- // filter() returned a nil response, this means suppress
- // writing a response, for the case where there might
- // be multiple response writers.
- return false
- }
-
- // the filter gave us a new response body, make sure that gets closed too.
- if resp.Body != originalBody {
- defer resp.Body.Close()
}
+ return 0, nil
}
+ defer resp.Body.Close()
for k, v := range resp.Header {
for _, v := range v {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
- n, err := io.Copy(w, resp.Body)
- if err != nil {
- httpserver.Logger(reqIn).WithError(err).WithField("bytesCopied", n).Error("error copying response body")
- }
- return true
+ return io.Copy(w, resp.Body)
}
commit 748c5e85538b145e51777fe1015b943546d9ca06
Author: Peter Amstutz <pamstutz at veritasgenetics.com>
Date: Thu Oct 18 14:34:49 2018 -0400
14262: Fix bug moving api_token to header
Arvados-DCO-1.1-Signed-off-by: Peter Amstutz <pamstutz at veritasgenetics.com>
diff --git a/lib/controller/federation.go b/lib/controller/federation.go
index e5c56bd83..f30365574 100644
--- a/lib/controller/federation.go
+++ b/lib/controller/federation.go
@@ -14,7 +14,6 @@ import (
"fmt"
"io"
"io/ioutil"
- "log"
"net/http"
"net/url"
"regexp"
@@ -61,7 +60,7 @@ func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, r
if scheme == "" {
scheme = "https"
}
- req, err := h.saltAuthToken(req, remoteID)
+ saltedReq, err := h.saltAuthToken(req, remoteID)
if err != nil {
if filter != nil {
_, err = filter(nil, err)
@@ -74,15 +73,15 @@ func (h *Handler) remoteClusterRequest(remoteID string, w http.ResponseWriter, r
urlOut := &url.URL{
Scheme: scheme,
Host: remote.Host,
- Path: req.URL.Path,
- RawPath: req.URL.RawPath,
- RawQuery: req.URL.RawQuery,
+ Path: saltedReq.URL.Path,
+ RawPath: saltedReq.URL.RawPath,
+ RawQuery: saltedReq.URL.RawQuery,
}
client := h.secureClient
if remote.Insecure {
client = h.insecureClient
}
- h.proxy.Do(w, req, urlOut, client, filter)
+ h.proxy.Do(w, saltedReq, urlOut, client, filter)
}
// Buffer request body, parse form parameters in request, and then
@@ -777,7 +776,6 @@ func (h *Handler) saltAuthToken(req *http.Request, remote string) (updatedReq *h
token, err := auth.SaltToken(creds.Tokens[0], remote)
- log.Printf("Salting %q %q to get %q %q", creds.Tokens[0], remote, token, err)
if err == auth.ErrObsoleteToken {
// If the token exists in our own database, salt it
// for the remote. Otherwise, assume it was issued by
@@ -801,14 +799,11 @@ func (h *Handler) saltAuthToken(req *http.Request, remote string) (updatedReq *h
}
updatedReq.Header = http.Header{}
for k, v := range req.Header {
- if k == "Authorization" {
- updatedReq.Header[k] = []string{"Bearer " + token}
- } else {
+ if k != "Authorization" {
updatedReq.Header[k] = v
}
}
-
- log.Printf("Salted %q %q to get %q", creds.Tokens[0], remote, token)
+ updatedReq.Header.Set("Authorization", "Bearer "+token)
// Remove api_token=... from the the query string, in case we
// end up forwarding the request.
-----------------------------------------------------------------------
hooks/post-receive
--
More information about the arvados-commits
mailing list