diff --git a/google/data_source_tpu_tensorflow_versions.go b/google/data_source_tpu_tensorflow_versions.go new file mode 100644 index 00000000..0c36811b --- /dev/null +++ b/google/data_source_tpu_tensorflow_versions.go @@ -0,0 +1,82 @@ +package google + +import ( + "fmt" + "log" + "sort" + "time" + + "github.com/hashicorp/terraform/helper/schema" +) + +func dataSourceTpuTensorflowVersions() *schema.Resource { + return &schema.Resource{ + Read: dataSourceTpuTensorFlowVersionsRead, + Schema: map[string]*schema.Schema{ + "project": { + Type: schema.TypeString, + Optional: true, + Computed: true, + }, + "zone": { + Type: schema.TypeString, + Optional: true, + Computed: true, + }, + "versions": { + Type: schema.TypeList, + Computed: true, + Elem: &schema.Schema{Type: schema.TypeString}, + }, + }, + } +} + +func dataSourceTpuTensorFlowVersionsRead(d *schema.ResourceData, meta interface{}) error { + config := meta.(*Config) + + project, err := getProject(d, config) + if err != nil { + return err + } + + zone, err := getZone(d, config) + if err != nil { + return err + } + + url, err := replaceVars(d, config, "https://tpu.googleapis.com/v1/projects/{{project}}/locations/{{zone}}/tensorflowVersions") + if err != nil { + return err + } + + versionsRaw, err := paginatedListRequest(url, config, flattenTpuTensorflowVersions) + if err != nil { + return fmt.Errorf("Error listing TPU Tensorflow versions: %s", err) + } + + versions := make([]string, len(versionsRaw)) + for i, ver := range versionsRaw { + versions[i] = ver.(string) + } + sort.Strings(versions) + + log.Printf("[DEBUG] Received Google TPU Tensorflow Versions: %q", versions) + + d.Set("versions", versions) + d.Set("zone", zone) + d.Set("project", project) + d.SetId(time.Now().UTC().String()) + + return nil +} + +func flattenTpuTensorflowVersions(resp map[string]interface{}) []interface{} { + verObjList := resp["tensorflowVersions"].([]interface{}) + versions := make([]interface{}, len(verObjList)) + for i, v := range verObjList { + verObj := v.(map[string]interface{}) + versions[i] = verObj["version"] + } + return versions +} diff --git a/google/data_source_tpu_tensorflow_versions_test.go b/google/data_source_tpu_tensorflow_versions_test.go new file mode 100644 index 00000000..176ec3e6 --- /dev/null +++ b/google/data_source_tpu_tensorflow_versions_test.go @@ -0,0 +1,72 @@ +package google + +import ( + "errors" + "fmt" + "strconv" + "testing" + + "github.com/hashicorp/terraform/helper/resource" + "github.com/hashicorp/terraform/terraform" + "regexp" +) + +func TestAccTpuTensorflowVersions_basic(t *testing.T) { + t.Parallel() + + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + Steps: []resource.TestStep{ + { + Config: testAccTpuTensorFlowVersionsConfig, + Check: resource.ComposeTestCheckFunc( + testAccCheckGoogleTpuTensorflowVersions("data.google_tpu_tensorflow_versions.available"), + ), + }, + }, + }) +} + +func testAccCheckGoogleTpuTensorflowVersions(n string) resource.TestCheckFunc { + return func(s *terraform.State) error { + rs, ok := s.RootModule().Resources[n] + if !ok { + return fmt.Errorf("Can't find TPU Tensorflow versions data source: %s", n) + } + + if rs.Primary.ID == "" { + return errors.New("data source ID not set.") + } + + count, ok := rs.Primary.Attributes["versions.#"] + if !ok { + return errors.New("can't find 'names' attribute") + } + + cnt, err := strconv.Atoi(count) + if err != nil { + return errors.New("failed to read number of version") + } + if cnt < 2 { + return fmt.Errorf("expected at least 2 versions, received %d, this is most likely a bug", cnt) + } + + for i := 0; i < cnt; i++ { + idx := fmt.Sprintf("versions.%d", i) + v, ok := rs.Primary.Attributes[idx] + if !ok { + return fmt.Errorf("expected %q, version not found", idx) + } + + if !regexp.MustCompile(`^([0-9]+\.)+[0-9]+$`).MatchString(v) { + return fmt.Errorf("unexpected version format for %q, value is %v", idx, v) + } + } + return nil + } +} + +var testAccTpuTensorFlowVersionsConfig = ` +data "google_tpu_tensorflow_versions" "available" {} +` diff --git a/google/provider.go b/google/provider.go index 4f5dfa02..8822889f 100644 --- a/google/provider.go +++ b/google/provider.go @@ -122,6 +122,7 @@ func Provider() terraform.ResourceProvider { "google_storage_object_signed_url": dataSourceGoogleSignedUrl(), "google_storage_project_service_account": dataSourceGoogleStorageProjectServiceAccount(), "google_storage_transfer_project_service_account": dataSourceGoogleStorageTransferProjectServiceAccount(), + "google_tpu_tensorflow_versions": dataSourceTpuTensorflowVersions(), }, ResourcesMap: ResourceMap(), diff --git a/google/resource_tpu_node_generated_test.go b/google/resource_tpu_node_generated_test.go index 61b8cfb0..f2c731d3 100644 --- a/google/resource_tpu_node_generated_test.go +++ b/google/resource_tpu_node_generated_test.go @@ -51,12 +51,14 @@ func TestAccTpuNode_tpuNodeBasicExample(t *testing.T) { func testAccTpuNode_tpuNodeBasicExample(context map[string]interface{}) string { return Nprintf(` +data "google_tpu_tensorflow_versions" "available" { } + resource "google_tpu_node" "tpu" { name = "test-tpu-%{random_suffix}" zone = "us-central1-b" accelerator_type = "v3-8" - tensorflow_version = "1.13" + tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}" cidr_block = "10.2.0.0/29" } `, context) @@ -94,6 +96,8 @@ resource "google_compute_network" "tpu_network" { auto_create_subnetworks = false } +data "google_tpu_tensorflow_versions" "available" { } + resource "google_tpu_node" "tpu" { name = "test-tpu-%{random_suffix}" zone = "us-central1-b" @@ -101,7 +105,7 @@ resource "google_tpu_node" "tpu" { accelerator_type = "v3-8" cidr_block = "10.3.0.0/29" - tensorflow_version = "1.13" + tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}" description = "Terraform Google Provider test TPU" network = "${google_compute_network.tpu_network.name}" diff --git a/google/resource_tpu_node_test.go b/google/resource_tpu_node_test.go index dd107b79..d0d15675 100644 --- a/google/resource_tpu_node_test.go +++ b/google/resource_tpu_node_test.go @@ -19,7 +19,7 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) { CheckDestroy: testAccCheckTpuNodeDestroy, Steps: []resource.TestStep{ { - Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, "1.11"), + Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, 0), }, { ResourceName: "google_tpu_node.tpu", @@ -28,7 +28,7 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) { ImportStateVerifyIgnore: []string{"zone"}, }, { - Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, "1.12"), + Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, 1), }, { ResourceName: "google_tpu_node.tpu", @@ -43,15 +43,17 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) { // WARNING: cidr_block must not overlap with other existing TPU blocks // Make sure if you change this value that it does not overlap with the // autogenerated examples. -func testAccTpuNode_tpuNodeTensorFlow(nodeId, tensorFlowVer string) string { +func testAccTpuNode_tpuNodeTensorFlow(nodeId string, versionIdx int) string { return fmt.Sprintf(` +data "google_tpu_tensorflow_versions" "available" { } + resource "google_tpu_node" "tpu" { name = "%s" zone = "us-central1-b" accelerator_type = "v3-8" - tensorflow_version = "%s" + tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[%d]}" cidr_block = "10.1.0.0/29" } -`, nodeId, tensorFlowVer) +`, nodeId, versionIdx) } diff --git a/google/utils.go b/google/utils.go index 4731da7f..45a17db5 100644 --- a/google/utils.go +++ b/google/utils.go @@ -423,3 +423,27 @@ func serviceAccountFQN(serviceAccount string, d TerraformResourceData, config *C return fmt.Sprintf("projects/-/serviceAccounts/%s@%s.iam.gserviceaccount.com", serviceAccount, project), nil } + +func paginatedListRequest(baseUrl string, config *Config, flattener func(map[string]interface{}) []interface{}) ([]interface{}, error) { + res, err := sendRequest(config, "GET", baseUrl, nil) + if err != nil { + return nil, err + } + + ls := flattener(res) + pageToken, ok := res["pageToken"] + for ok { + if pageToken.(string) == "" { + break + } + url := fmt.Sprintf("%s?pageToken=%s", baseUrl, pageToken.(string)) + res, err = sendRequest(config, "GET", url, nil) + if err != nil { + return nil, err + } + ls = append(ls, flattener(res)) + pageToken, ok = res["pageToken"] + } + + return ls, nil +} diff --git a/website/docs/r/tpu_node.html.markdown b/website/docs/r/tpu_node.html.markdown index a96cce95..2059607a 100644 --- a/website/docs/r/tpu_node.html.markdown +++ b/website/docs/r/tpu_node.html.markdown @@ -39,12 +39,14 @@ To get more information about Node, see: ```hcl +data "google_tpu_tensorflow_versions" "available" { } + resource "google_tpu_node" "tpu" { name = "test-tpu" zone = "us-central1-b" accelerator_type = "v3-8" - tensorflow_version = "1.13" + tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}" cidr_block = "10.2.0.0/29" } ``` @@ -62,6 +64,8 @@ resource "google_compute_network" "tpu_network" { auto_create_subnetworks = false } +data "google_tpu_tensorflow_versions" "available" { } + resource "google_tpu_node" "tpu" { name = "test-tpu" zone = "us-central1-b" @@ -69,7 +73,7 @@ resource "google_tpu_node" "tpu" { accelerator_type = "v3-8" cidr_block = "10.3.0.0/29" - tensorflow_version = "1.13" + tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}" description = "Terraform Google Provider test TPU" network = "${google_compute_network.tpu_network.name}"