Add TPU Tensorflow datasource (#3325)

Signed-off-by: Modular Magician <magic-modules@google.com>
This commit is contained in:
The Magician 2019-03-27 16:26:07 -07:00 committed by emily
parent 0fccddf4d2
commit 59b08fb422
7 changed files with 198 additions and 9 deletions

View File

@ -0,0 +1,82 @@
package google
import (
"fmt"
"log"
"sort"
"time"
"github.com/hashicorp/terraform/helper/schema"
)
func dataSourceTpuTensorflowVersions() *schema.Resource {
return &schema.Resource{
Read: dataSourceTpuTensorFlowVersionsRead,
Schema: map[string]*schema.Schema{
"project": {
Type: schema.TypeString,
Optional: true,
Computed: true,
},
"zone": {
Type: schema.TypeString,
Optional: true,
Computed: true,
},
"versions": {
Type: schema.TypeList,
Computed: true,
Elem: &schema.Schema{Type: schema.TypeString},
},
},
}
}
func dataSourceTpuTensorFlowVersionsRead(d *schema.ResourceData, meta interface{}) error {
config := meta.(*Config)
project, err := getProject(d, config)
if err != nil {
return err
}
zone, err := getZone(d, config)
if err != nil {
return err
}
url, err := replaceVars(d, config, "https://tpu.googleapis.com/v1/projects/{{project}}/locations/{{zone}}/tensorflowVersions")
if err != nil {
return err
}
versionsRaw, err := paginatedListRequest(url, config, flattenTpuTensorflowVersions)
if err != nil {
return fmt.Errorf("Error listing TPU Tensorflow versions: %s", err)
}
versions := make([]string, len(versionsRaw))
for i, ver := range versionsRaw {
versions[i] = ver.(string)
}
sort.Strings(versions)
log.Printf("[DEBUG] Received Google TPU Tensorflow Versions: %q", versions)
d.Set("versions", versions)
d.Set("zone", zone)
d.Set("project", project)
d.SetId(time.Now().UTC().String())
return nil
}
func flattenTpuTensorflowVersions(resp map[string]interface{}) []interface{} {
verObjList := resp["tensorflowVersions"].([]interface{})
versions := make([]interface{}, len(verObjList))
for i, v := range verObjList {
verObj := v.(map[string]interface{})
versions[i] = verObj["version"]
}
return versions
}

View File

@ -0,0 +1,72 @@
package google
import (
"errors"
"fmt"
"strconv"
"testing"
"github.com/hashicorp/terraform/helper/resource"
"github.com/hashicorp/terraform/terraform"
"regexp"
)
func TestAccTpuTensorflowVersions_basic(t *testing.T) {
t.Parallel()
resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
Steps: []resource.TestStep{
{
Config: testAccTpuTensorFlowVersionsConfig,
Check: resource.ComposeTestCheckFunc(
testAccCheckGoogleTpuTensorflowVersions("data.google_tpu_tensorflow_versions.available"),
),
},
},
})
}
func testAccCheckGoogleTpuTensorflowVersions(n string) resource.TestCheckFunc {
return func(s *terraform.State) error {
rs, ok := s.RootModule().Resources[n]
if !ok {
return fmt.Errorf("Can't find TPU Tensorflow versions data source: %s", n)
}
if rs.Primary.ID == "" {
return errors.New("data source ID not set.")
}
count, ok := rs.Primary.Attributes["versions.#"]
if !ok {
return errors.New("can't find 'names' attribute")
}
cnt, err := strconv.Atoi(count)
if err != nil {
return errors.New("failed to read number of version")
}
if cnt < 2 {
return fmt.Errorf("expected at least 2 versions, received %d, this is most likely a bug", cnt)
}
for i := 0; i < cnt; i++ {
idx := fmt.Sprintf("versions.%d", i)
v, ok := rs.Primary.Attributes[idx]
if !ok {
return fmt.Errorf("expected %q, version not found", idx)
}
if !regexp.MustCompile(`^([0-9]+\.)+[0-9]+$`).MatchString(v) {
return fmt.Errorf("unexpected version format for %q, value is %v", idx, v)
}
}
return nil
}
}
var testAccTpuTensorFlowVersionsConfig = `
data "google_tpu_tensorflow_versions" "available" {}
`

View File

@ -122,6 +122,7 @@ func Provider() terraform.ResourceProvider {
"google_storage_object_signed_url": dataSourceGoogleSignedUrl(),
"google_storage_project_service_account": dataSourceGoogleStorageProjectServiceAccount(),
"google_storage_transfer_project_service_account": dataSourceGoogleStorageTransferProjectServiceAccount(),
"google_tpu_tensorflow_versions": dataSourceTpuTensorflowVersions(),
},
ResourcesMap: ResourceMap(),

View File

@ -51,12 +51,14 @@ func TestAccTpuNode_tpuNodeBasicExample(t *testing.T) {
func testAccTpuNode_tpuNodeBasicExample(context map[string]interface{}) string {
return Nprintf(`
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu-%{random_suffix}"
zone = "us-central1-b"
accelerator_type = "v3-8"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
cidr_block = "10.2.0.0/29"
}
`, context)
@ -94,6 +96,8 @@ resource "google_compute_network" "tpu_network" {
auto_create_subnetworks = false
}
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu-%{random_suffix}"
zone = "us-central1-b"
@ -101,7 +105,7 @@ resource "google_tpu_node" "tpu" {
accelerator_type = "v3-8"
cidr_block = "10.3.0.0/29"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
description = "Terraform Google Provider test TPU"
network = "${google_compute_network.tpu_network.name}"

View File

@ -19,7 +19,7 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) {
CheckDestroy: testAccCheckTpuNodeDestroy,
Steps: []resource.TestStep{
{
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, "1.11"),
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, 0),
},
{
ResourceName: "google_tpu_node.tpu",
@ -28,7 +28,7 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) {
ImportStateVerifyIgnore: []string{"zone"},
},
{
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, "1.12"),
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, 1),
},
{
ResourceName: "google_tpu_node.tpu",
@ -43,15 +43,17 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) {
// WARNING: cidr_block must not overlap with other existing TPU blocks
// Make sure if you change this value that it does not overlap with the
// autogenerated examples.
func testAccTpuNode_tpuNodeTensorFlow(nodeId, tensorFlowVer string) string {
func testAccTpuNode_tpuNodeTensorFlow(nodeId string, versionIdx int) string {
return fmt.Sprintf(`
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "%s"
zone = "us-central1-b"
accelerator_type = "v3-8"
tensorflow_version = "%s"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[%d]}"
cidr_block = "10.1.0.0/29"
}
`, nodeId, tensorFlowVer)
`, nodeId, versionIdx)
}

View File

@ -423,3 +423,27 @@ func serviceAccountFQN(serviceAccount string, d TerraformResourceData, config *C
return fmt.Sprintf("projects/-/serviceAccounts/%s@%s.iam.gserviceaccount.com", serviceAccount, project), nil
}
func paginatedListRequest(baseUrl string, config *Config, flattener func(map[string]interface{}) []interface{}) ([]interface{}, error) {
res, err := sendRequest(config, "GET", baseUrl, nil)
if err != nil {
return nil, err
}
ls := flattener(res)
pageToken, ok := res["pageToken"]
for ok {
if pageToken.(string) == "" {
break
}
url := fmt.Sprintf("%s?pageToken=%s", baseUrl, pageToken.(string))
res, err = sendRequest(config, "GET", url, nil)
if err != nil {
return nil, err
}
ls = append(ls, flattener(res))
pageToken, ok = res["pageToken"]
}
return ls, nil
}

View File

@ -39,12 +39,14 @@ To get more information about Node, see:
```hcl
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu"
zone = "us-central1-b"
accelerator_type = "v3-8"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
cidr_block = "10.2.0.0/29"
}
```
@ -62,6 +64,8 @@ resource "google_compute_network" "tpu_network" {
auto_create_subnetworks = false
}
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu"
zone = "us-central1-b"
@ -69,7 +73,7 @@ resource "google_tpu_node" "tpu" {
accelerator_type = "v3-8"
cidr_block = "10.3.0.0/29"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
description = "Terraform Google Provider test TPU"
network = "${google_compute_network.tpu_network.name}"