package mssql

import (
	"fmt"
	"log"
	"time"

	"github.com/Azure/azure-sdk-for-go/services/preview/sql/mgmt/v3.0/sql"
	"github.com/hashicorp/terraform-plugin-sdk/helper/schema"
	"github.com/hashicorp/terraform-plugin-sdk/helper/validation"
	"github.com/terraform-providers/terraform-provider-azurerm/azurerm/helpers/azure"
	"github.com/terraform-providers/terraform-provider-azurerm/azurerm/helpers/validate"
	"github.com/terraform-providers/terraform-provider-azurerm/azurerm/internal/clients"
	"github.com/terraform-providers/terraform-provider-azurerm/azurerm/internal/timeouts"
	"github.com/terraform-providers/terraform-provider-azurerm/azurerm/utils"
)

// resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaseline defines the Terraform resource for a
// baseline attached to a single vulnerability assessment rule of a SQL database.
//
// The baseline is addressed by the server vulnerability assessment it belongs to, the database
// name, the rule ID and the baseline name; all of these force a new resource when changed. Only
// the set of baseline result rows can be updated in place.
func resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaseline() *schema.Resource {
	return &schema.Resource{
		Create: resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate,
		Read:   resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineRead,
		Update: resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate,
		Delete: resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineDelete,
		Importer: &schema.ResourceImporter{
			State: schema.ImportStatePassthrough,
		},

		Timeouts: &schema.ResourceTimeout{
			Create: schema.DefaultTimeout(30 * time.Minute),
			Read:   schema.DefaultTimeout(5 * time.Minute),
			Update: schema.DefaultTimeout(30 * time.Minute),
			Delete: schema.DefaultTimeout(30 * time.Minute),
		},

		Schema: map[string]*schema.Schema{
			// Resource ID of the server-level vulnerability assessment; the resource group and
			// server name are taken from it.
			"server_vulnerability_assessment_id": {
				Type:         schema.TypeString,
				Required:     true,
				ForceNew:     true,
				ValidateFunc: azure.ValidateResourceID,
			},

			"database_name": {
				Type:         schema.TypeString,
				Required:     true,
				ForceNew:     true,
				ValidateFunc: azure.ValidateMsSqlDatabaseName,
			},

			// The rule ID becomes a path segment of the request URL, so an empty value would
			// produce a malformed request; reject it at plan time.
			"rule_id": {
				Type:         schema.TypeString,
				Required:     true,
				ForceNew:     true,
				ValidateFunc: validate.NoEmptyStrings,
			},

			"baseline_name": {
				Type:     schema.TypeString,
				Optional: true,
				ForceNew: true,
				Default:  string(sql.VulnerabilityAssessmentPolicyBaselineNameDefault),
				ValidateFunc: validation.StringInSlice([]string{
					string(sql.VulnerabilityAssessmentPolicyBaselineNameDefault),
					string(sql.VulnerabilityAssessmentPolicyBaselineNameMaster),
				}, false),
			},

			// Each block is one baseline row; "result" holds that row's column values in order.
			"baseline_result": {
				Type:     schema.TypeSet,
				Required: true,
				Elem: &schema.Resource{
					Schema: map[string]*schema.Schema{
						"result": {
							Type:     schema.TypeList,
							Required: true,
							Elem: &schema.Schema{
								Type:         schema.TypeString,
								ValidateFunc: validate.NoEmptyStrings,
							},
						},
					},
				},
			},
		},
	}
}

// resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate creates or replaces the
// rule baseline described by d and then refreshes state via Read.
//
// Before writing, it loads the server vulnerability assessment referenced by
// server_vulnerability_assessment_id and refuses to continue when that assessment has no storage
// container path configured.
func resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate(d *schema.ResourceData, meta interface{}) error {
	client := meta.(*clients.Client).MSSQL.DatabaseVulnerabilityAssessmentRuleBaselinesClient
	ctx, cancel := timeouts.ForCreateUpdate(meta.(*clients.Client).StopContext, d)
	defer cancel()

	log.Printf("[INFO] preparing arguments for Azure ARM Vulnerability Assessment Rule Baselines creation.")

	rawVulnerabilityAssessmentId := d.Get("server_vulnerability_assessment_id").(string)
	vulnerabilityAssessmentId, err := azure.ParseAzureResourceID(rawVulnerabilityAssessmentId)
	if err != nil {
		return err
	}

	resourceGroupName := vulnerabilityAssessmentId.ResourceGroup
	serverName := vulnerabilityAssessmentId.Path["servers"]
	// ValidateResourceID only checks the generic ARM shape, so make sure the server segment exists
	// before it is used to build request URLs.
	if serverName == "" {
		return fmt.Errorf("parsing Server Vulnerability Assessment ID %q: missing `servers` segment", rawVulnerabilityAssessmentId)
	}
	vulnerabilityAssessmentClient := meta.(*clients.Client).MSSQL.ServerVulnerabilityAssessmentsClient

	vulnerabilityAssessment, err := vulnerabilityAssessmentClient.Get(ctx, resourceGroupName, serverName)
	if err != nil {
		return fmt.Errorf("Error retrieving Server Vulnerability Assessment Settings: %+v", err)
	}
	// StorageContainerPath is promoted through an embedded *properties pointer; reading it while
	// that pointer is nil would panic, so check both levels.
	if props := vulnerabilityAssessment.ServerVulnerabilityAssessmentProperties; props == nil || props.StorageContainerPath == nil {
		return fmt.Errorf("Storage Container Path not set in Server Vulnerability Assessment Settings")
	}

	databaseName := d.Get("database_name").(string)
	ruleId := d.Get("rule_id").(string)
	baselineName := sql.VulnerabilityAssessmentPolicyBaselineName(d.Get("baseline_name").(string))
	parameters := expandBaselineResults(d.Get("baseline_result").(*schema.Set))

	result, err := client.CreateOrUpdate(ctx, resourceGroupName, serverName, databaseName, ruleId, baselineName, *parameters)
	if err != nil {
		return fmt.Errorf("error updating database vulnerability assessment rule baseline: %+v", err)
	}

	// Guard against a response without an ID rather than panicking on the dereference.
	if result.ID == nil || *result.ID == "" {
		return fmt.Errorf("error updating database vulnerability assessment rule baseline: API returned an empty ID")
	}
	d.SetId(*result.ID)

	return resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineRead(d, meta)
}

// resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineRead refreshes state for the rule
// baseline identified by d.Id().
//
// The ID's servers/databases/rules/baselines segments address the baseline. A 404 from the API
// clears the ID so Terraform treats the resource as gone. The parent server vulnerability
// assessment is fetched as well, so that server_vulnerability_assessment_id is populated (this
// matters after import).
func resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineRead(d *schema.ResourceData, meta interface{}) error {
	client := meta.(*clients.Client).MSSQL.DatabaseVulnerabilityAssessmentRuleBaselinesClient
	ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)
	defer cancel()

	log.Printf("[INFO] Reading Azure ARM Vulnerability Assessment Rule Baselines.")

	id, err := azure.ParseAzureResourceID(d.Id())
	if err != nil {
		return err
	}

	resourceGroupName := id.ResourceGroup
	serverName := id.Path["servers"]
	databaseName := id.Path["databases"]
	ruleId := id.Path["rules"]
	baselineName := sql.VulnerabilityAssessmentPolicyBaselineName(id.Path["baselines"])

	// Imported IDs are not validated elsewhere; reject ones lacking a required segment instead of
	// issuing a request with empty path components.
	if serverName == "" || databaseName == "" || ruleId == "" || baselineName == "" {
		return fmt.Errorf("parsing Vulnerability Assessment Rule Baseline ID %q: expected `servers`, `databases`, `rules` and `baselines` segments", d.Id())
	}

	result, err := client.Get(ctx, resourceGroupName, serverName, databaseName, ruleId, baselineName)
	if err != nil {
		if utils.ResponseWasNotFound(result.Response) {
			log.Printf("[WARN] Vulnerability Assessment Rule Baseline %s not found", id)
			d.SetId("")
			return nil
		}

		return fmt.Errorf("error making Read request to Vulnerability Assessment Rule Baselines: %+v", err)
	}

	vulnerabilityAssessmentClient := meta.(*clients.Client).MSSQL.ServerVulnerabilityAssessmentsClient
	vulnerabilityAssessment, err := vulnerabilityAssessmentClient.Get(ctx, resourceGroupName, serverName)
	if err != nil {
		return fmt.Errorf("Error retrieving Server Vulnerability Assessment Settings: %+v", err)
	}
	d.Set("server_vulnerability_assessment_id", vulnerabilityAssessment.ID)

	d.Set("database_name", databaseName)
	d.Set("rule_id", ruleId)
	d.Set("baseline_name", string(baselineName))

	// BaselineResults is promoted through an embedded *properties pointer; dereferencing it while
	// that pointer is nil would panic, so check the properties first.
	if props := result.DatabaseVulnerabilityAssessmentRuleBaselineProperties; props != nil {
		if err := d.Set("baseline_result", flattenBaselineResult(props.BaselineResults)); err != nil {
			return fmt.Errorf("error setting `baseline_result`: %+v", err)
		}
	}

	return nil
}

// resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineDelete removes the rule baseline
// addressed by d.Id(). A 404 response is treated as success, because the baseline is already gone.
func resourceArmMssqlDatabaseVulnerabilityAssessmentRuleBaselineDelete(d *schema.ResourceData, meta interface{}) error {
	armClient := meta.(*clients.Client)
	baselinesClient := armClient.MSSQL.DatabaseVulnerabilityAssessmentRuleBaselinesClient
	ctx, cancel := timeouts.ForDelete(armClient.StopContext, d)
	defer cancel()

	log.Printf("[INFO] Deleting Azure ARM Vulnerability Assessment Rule Baselines.")

	parsed, err := azure.ParseAzureResourceID(d.Id())
	if err != nil {
		return err
	}

	segments := parsed.Path
	resp, err := baselinesClient.Delete(
		ctx,
		parsed.ResourceGroup,
		segments["servers"],
		segments["databases"],
		segments["rules"],
		sql.VulnerabilityAssessmentPolicyBaselineName(segments["baselines"]),
	)
	if err == nil {
		return nil
	}

	// Already deleted out-of-band: nothing left to do.
	if utils.ResponseWasNotFound(resp) {
		log.Printf("[DEBUG] Vulnerability Assessment Rule Baseline %s not found", parsed)
		return nil
	}

	return fmt.Errorf("error deleting Vulnerability Assessment Rule Baselines: %s", err)
}

// expandBaselineResults converts the Terraform "baseline_result" set into the SDK request body.
// Each set element contributes one baseline item whose Result holds that element's "result" list
// converted to strings, preserving the list order.
func expandBaselineResults(baselineResult *schema.Set) *sql.DatabaseVulnerabilityAssessmentRuleBaseline {
	rows := baselineResult.List()
	items := make([]sql.DatabaseVulnerabilityAssessmentRuleBaselineItem, 0, len(rows))

	for _, row := range rows {
		columns := row.(map[string]interface{})["result"].([]interface{})

		values := make([]string, 0, len(columns))
		for _, column := range columns {
			values = append(values, column.(string))
		}

		items = append(items, sql.DatabaseVulnerabilityAssessmentRuleBaselineItem{Result: &values})
	}

	properties := &sql.DatabaseVulnerabilityAssessmentRuleBaselineProperties{BaselineResults: &items}
	return &sql.DatabaseVulnerabilityAssessmentRuleBaseline{
		DatabaseVulnerabilityAssessmentRuleBaselineProperties: properties,
	}
}

// flattenBaselineResult converts SDK baseline items into the shape used by the "baseline_result"
// schema. It returns an empty, non-nil slice when the input is nil. Items whose Result is nil
// become an empty map (no "result" key).
func flattenBaselineResult(baselineResults *[]sql.DatabaseVulnerabilityAssessmentRuleBaselineItem) []map[string]interface{} {
	if baselineResults == nil {
		return make([]map[string]interface{}, 0)
	}

	items := *baselineResults
	flattened := make([]map[string]interface{}, 0, len(items))
	for _, item := range items {
		entry := make(map[string]interface{})
		if item.Result != nil {
			entry["result"] = *item.Result
		}
		flattened = append(flattened, entry)
	}

	return flattened
}
