/*
 * This file is part of LibEuFin.
 * Copyright (C) 2024 Taler Systems S.A.

 * LibEuFin is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation; either version 3, or
 * (at your option) any later version.

 * LibEuFin is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.

 * You should have received a copy of the GNU Affero General Public
 * License along with LibEuFin; see the file COPYING.  If not, see
 * <http://www.gnu.org/licenses/>
 */

import io.ktor.client.request.*
import io.ktor.http.*
import org.junit.Test
import org.postgresql.jdbc.PgConnection
import tech.libeufin.bank.*
import tech.libeufin.common.*
import tech.libeufin.common.crypto.PwCrypto
import java.time.Instant
import java.time.LocalDateTime
import java.time.ZoneId
import java.util.*
import kotlin.math.max
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.time.DurationUnit
import kotlin.time.measureTime
import kotlin.time.toDuration

/**
 * Bank benchmark: pre-fills the database with synthetic rows, then times a series of
 * HTTP actions and prints min/mean/max/std per action.
 *
 * Controlled by environment variables:
 * - `BENCH_ITER`: number of iterations per measured action; the benchmark is skipped
 *   when it is absent, not a number, or not positive.
 * - `BENCH_AMOUNT`: requested number of synthetic rows per table (clamped to at least 10).
 */
class Bench {

    /**
     * Bulk-load synthetic rows into the bank tables using PostgreSQL `COPY ... FROM STDIN`,
     * then run `VACUUM ANALYZE` so the planner has fresh statistics.
     *
     * Every table receives max([amount], 10) rows, except `bank_account_transactions`,
     * which receives two rows (a credit and a debit) per index. For per-account tables,
     * indexes up to the midpoint belong to freshly generated accounts and the remaining
     * ones belong to the customer account.
     *
     * @param conn open connection used for `COPY` and `VACUUM ANALYZE`
     * @param amount requested number of rows per table
     */
    fun genData(conn: PgConnection, amount: Int) {
        val rows = max(amount, 10)

        // Skip the 4 accounts created by bankSetup: generated account n gets id n + skipAccount
        val skipAccount = 4
        // The customer account is used in the tests, so we generate more data for it
        // NOTE(review): assumes bankSetup creates "customer" as id 3 and "exchange" as id 2
        val customerAccount = 3
        val exchangeAccount = 2
        // In general half of the data is for generated accounts and half is for the customer
        val mid = rows / 2

        /** Owner id of row [idx]: a generated account for the first half, the customer after. */
        fun ownerOf(idx: Int): Int = if (idx > mid) customerAccount else idx + skipAccount

        val copyManager = conn.getCopyAPI()
        val password = PwCrypto.hashpw("password")

        /**
         * Build [rows] lines with [lambda] (called with a 1-based index) and stream them
         * into [table] (a table name with an explicit column list) using COPY text format.
         */
        fun gen(table: String, lambda: (Int) -> String) {
            println("Gen rows for $table")
            val full = buildString(150 * rows) {
                repeat(rows) { idx ->
                    append(lambda(idx + 1))
                }
            }
            copyManager.copyIn("COPY $table FROM STDIN", full.reader())
        }

        // Reusable buffers refilled with random bytes for bytea columns. Bytea values are
        // written as `\\x<hex>` so that COPY's text-format unescaping yields `\x<hex>`.
        val token32 = ByteArray(32)
        val token64 = ByteArray(64)

        gen("customers(login, name, password_hash, cashout_payto)") {
            // Values follow the column list order: login, name, password_hash, cashout_payto
            "account_$it\tMr n°$it\t$password\t$unknownPayto\n"
        }
        gen("bank_accounts(internal_payto_uri, owning_customer_id, is_public)") {
            "payto://x-taler-bank/localhost/account_$it\t${it + skipAccount}\t${it % 3 == 0}\n"
        }
        gen("bearer_tokens(content, creation_time, expiration_time, scope, is_refreshable, bank_customer, description, last_access)") {
            val hex = token32.rand().encodeHex()
            "\\\\x$hex\t0\t0\treadonly\tfalse\t${ownerOf(it)}\t\\N\t0\n"
        }
        gen("bank_account_transactions(creditor_payto_uri, creditor_name, debtor_payto_uri, debtor_name, subject, amount, transaction_date, direction, bank_account_id)") {
            // Each index yields two rows: a credit on the exchange account (row 2*it-1)
            // and a debit on the owner account (row 2*it)
            "creditor_payto\tcreditor_name\tdebtor_payto\tdebtor_name\tsubject\t(42,0)\t0\tcredit\t$exchangeAccount\n" +
                "creditor_payto\tcreditor_name\tdebtor_payto\tdebtor_name\tsubject\t(42,0)\t0\tdebit\t${ownerOf(it)}\n"
        }
        gen("bank_transaction_operations") {
            val hex = token32.rand().encodeHex()
            "\\\\x$hex\t$it\n"
        }
        gen("tan_challenges(body, op, code, creation_date, expiration_date, retry_counter, customer)") {
            "body\taccount_reconfig\tcode\t0\t0\t0\t${ownerOf(it)}\n"
        }
        gen("taler_withdrawal_operations(withdrawal_uuid, wallet_bank_account, reserve_pub, creation_date)") {
            val hex = token32.rand().encodeHex()
            val uuid = UUID.randomUUID()
            "$uuid\t${ownerOf(it)}\t\\\\x$hex\t0\n"
        }
        gen("taler_exchange_outgoing(wtid, request_uid, exchange_base_url, bank_transaction, creditor_account_id)") {
            val hex32 = token32.rand().encodeHex()
            val hex64 = token64.rand().encodeHex()
            // References the odd (exchange-side) transaction rows generated above
            "\\\\x$hex32\t\\\\x$hex64\turl\t${it * 2 - 1}\t$it\n"
        }
        gen("taler_exchange_incoming(reserve_pub, bank_transaction)") {
            val hex = token32.rand().encodeHex()
            // References the even transaction rows generated above
            "\\\\x$hex\t${it * 2}\n"
        }
        gen("bank_stats(timeframe, start_time)") {
            // One distinct start time per row (one second apart, from the UTC epoch)
            val instant = Instant.ofEpochSecond(it.toLong())
            val date = LocalDateTime.ofInstant(instant, ZoneId.of("UTC"))
            "day\t$date\n"
        }
        gen("cashout_operations(request_uid,amount_debit,amount_credit,subject,creation_time,bank_account,local_transaction)") {
            val hex = token32.rand().encodeHex()
            "\\\\x$hex\t(0,0)\t(0,0)\tsubject\t0\t${ownerOf(it)}\t$it\n"
        }

        // Update database statistics for better perf
        conn.execSQLUpdate("VACUUM ANALYZE")
    }

    /**
     * Time each bank API action `BENCH_ITER` times against a database pre-filled with
     * `BENCH_AMOUNT` synthetic rows, then print a colored summary table.
     */
    @Test
    fun benchDb() {
        val iterations = System.getenv("BENCH_ITER")?.toIntOrNull() ?: 0
        val rowCount = System.getenv("BENCH_AMOUNT")?.toIntOrNull() ?: 0

        // LongArray(iterations) below would throw on a negative count, so skip early
        if (iterations <= 0) {
            println("Skip benchmark, missing BENCH_ITER")
            return
        }
        println("Bench $iterations times with $rowCount rows")

        // Durations above these thresholds are colored yellow / red in the report
        val warnThreshold = 4.toDuration(DurationUnit.MILLISECONDS)
        val errThreshold = 50.toDuration(DurationUnit.MILLISECONDS)

        /**
         * Format min, mean, max and (population) standard deviation of [times], given in
         * microseconds, each colored according to the warn/err thresholds.
         */
        fun fmtMeasures(times: LongArray): List<String> {
            val mean = times.average()
            // Variance around the exact mean rather than a truncated one
            val variance = times.map { (it - mean).pow(2) }.average()
            val stdDev = sqrt(variance)
            return listOf(times.min(), mean.toLong(), times.max(), stdDev.toLong()).map { micros ->
                val duration = micros.toDuration(DurationUnit.MICROSECONDS)
                val str = duration.toString()
                when {
                    duration > errThreshold -> ANSI.red(str)
                    duration > warnThreshold -> ANSI.yellow(str)
                    else -> ANSI.green(str)
                }
            }
        }

        val measures: MutableList<List<String>> = mutableListOf()

        /**
         * Run [lambda] once per iteration (passing the 0-based iteration index), record a
         * report row under [name], and return the per-iteration results in order so later
         * actions can reuse them by index.
         */
        suspend fun <R> measureAction(name: String, lambda: suspend (Int) -> R): List<R> {
            val results = mutableListOf<R>()
            val times = LongArray(iterations) { idx ->
                measureTime {
                    results.add(lambda(idx))
                }.inWholeMicroseconds
            }
            measures.add(listOf(ANSI.magenta(name)) + fmtMeasures(times))
            return results
        }

        bankSetup { db ->
            // Prepare customer accounts
            fillCashoutInfo("customer")
            setMaxDebt("customer", "KUDOS:1000000")

            // Generate data
            db.conn { genData(it, rowCount) }

            // Accounts
            val paytos = measureAction("account_create") {
                client.post("/accounts") {
                    json {
                        "username" to "account_bench_$it"
                        "password" to "password"
                        "name" to "Bench Account $it"
                    }
                }.assertOkJson<RegisterAccountResponse>().internal_payto_uri
            }
            measureAction("account_reconfig") {
                client.patch("/accounts/account_bench_$it") {
                    basicAuth("account_bench_$it", "password")
                    json {
                        "name" to "New Bench Account $it"
                    }
                }.assertNoContent()
            }
            measureAction("account_reconfig_auth") {
                client.patch("/accounts/account_bench_$it/auth") {
                    basicAuth("account_bench_$it", "password")
                    json {
                        "old_password" to "password"
                        "new_password" to "password"
                    }
                }.assertNoContent()
            }
            measureAction("account_list") {
                client.get("/accounts") {
                    basicAuth("admin", "admin-password")
                }.assertOk()
            }
            measureAction("account_list_public") {
                client.get("/public-accounts").assertOk()
            }
            measureAction("account_get") {
                client.get("/accounts/account_bench_$it") {
                    basicAuth("account_bench_$it", "password")
                }.assertOk()
            }
            measureAction("account_delete") {
                client.delete("/accounts/account_bench_$it") {
                    basicAuth("account_bench_$it", "password")
                }.assertNoContent()
            }

            // Tokens
            val tokens = measureAction("token_create") {
                client.postA("/accounts/customer/token") {
                    json {
                        "scope" to "readonly"
                        "refreshable" to true
                    }
                }.assertOkJson<TokenSuccessResponse>().access_token
            }
            measureAction("token_refresh") {
                client.post("/accounts/customer/token") {
                    headers[HttpHeaders.Authorization] = "Bearer ${tokens[it]}"
                    json { "scope" to "readonly" }
                }.assertOk()
            }
            measureAction("token_list") {
                client.getA("/accounts/customer/tokens").assertOk()
            }
            measureAction("token_delete") {
                client.delete("/accounts/customer/token") {
                    headers[HttpHeaders.Authorization] = "Bearer ${tokens[it]}"
                }.assertNoContent()
            }

            // Transaction
            val transactions = measureAction("transaction_create") {
                client.postA("/accounts/customer/transactions") {
                    json {
                        "payto_uri" to "$merchantPayto?receiver-name=Test&message=payout"
                        "amount" to "KUDOS:0.0001"
                    }
                }.assertOkJson<TransactionCreateResponse>().row_id
            }
            measureAction("transaction_get") {
                client.getA("/accounts/customer/transactions/${transactions[it]}").assertOk()
            }
            measureAction("transaction_history") {
                client.getA("/accounts/customer/transactions").assertOk()
            }
            measureAction("transaction_revenue") {
                client.getA("/accounts/merchant/taler-revenue/history").assertOk()
            }

            // Withdrawal
            val withdrawals = measureAction("withdrawal_create") {
                client.postA("/accounts/customer/withdrawals") {
                    json {
                        "amount" to "KUDOS:0.0001"
                    }
                }.assertOkJson<BankAccountCreateWithdrawalResponse>().withdrawal_id
            }
            measureAction("withdrawal_get") {
                client.get("/withdrawals/${withdrawals[it]}").assertOk()
            }
            measureAction("withdrawal_status") {
                client.get("/taler-integration/withdrawal-operation/${withdrawals[it]}").assertOk()
            }
            measureAction("withdrawal_select") {
                client.post("/taler-integration/withdrawal-operation/${withdrawals[it]}") {
                    json {
                        "reserve_pub" to EddsaPublicKey.rand()
                        "selected_exchange" to exchangePayto
                    }
                }.assertOk()
            }
            measureAction("withdrawal_confirm") {
                client.postA("/accounts/customer/withdrawals/${withdrawals[it]}/confirm")
                    .assertNoContent()
            }
            // Each iteration needs a fresh withdrawal, so the measured time includes its creation
            measureAction("withdrawal_abort") {
                val uuid = client.postA("/accounts/customer/withdrawals") {
                    json {
                        "amount" to "KUDOS:0.0001"
                    }
                }.assertOkJson<BankAccountCreateWithdrawalResponse>().withdrawal_id
                client.post("/taler-integration/withdrawal-operation/$uuid/abort")
                    .assertNoContent()
            }

            // Cashout: convert once outside the measured action so it is not timed
            val converted = convert("KUDOS:0.1")
            val cashouts = measureAction("cashout_create") {
                client.postA("/accounts/customer/cashouts") {
                    json {
                        "request_uid" to ShortHashCode.rand()
                        "amount_debit" to "KUDOS:0.1"
                        "amount_credit" to converted
                    }
                }.assertOkJson<CashoutResponse>().cashout_id
            }
            measureAction("cashout_get") {
                client.getA("/accounts/customer/cashouts/${cashouts[it]}").assertOk()
            }
            measureAction("cashout_history") {
                client.getA("/accounts/customer/cashouts").assertOk()
            }
            measureAction("cashout_history_admin") {
                client.get("/cashouts") {
                    pwAuth("admin")
                }.assertOk()
            }

            // Wire gateway
            measureAction("wg_transfer") {
                client.postA("/accounts/exchange/taler-wire-gateway/transfer") {
                    json {
                        "request_uid" to HashCode.rand()
                        "amount" to "KUDOS:0.0001"
                        "exchange_base_url" to "http://exchange.example.com/"
                        "wtid" to ShortHashCode.rand()
                        "credit_account" to customerPayto.canonical
                    }
                }.assertOk()
            }
            measureAction("wg_add") {
                client.postA("/accounts/exchange/taler-wire-gateway/admin/add-incoming") {
                    json {
                        "amount" to "KUDOS:0.0001"
                        "reserve_pub" to EddsaPublicKey.rand()
                        "debit_account" to customerPayto.canonical
                    }
                }.assertOk()
            }
            measureAction("wg_incoming") {
                client.getA("/accounts/exchange/taler-wire-gateway/history/incoming")
                    .assertOk()
            }
            measureAction("wg_outgoing") {
                client.getA("/accounts/exchange/taler-wire-gateway/history/outgoing")
                    .assertOk()
            }

            // TAN challenges
            val challenges = measureAction("tan_send") {
                val id = client.patchA("/accounts/customer") {
                    json {
                        "contact_data" to obj {
                            "phone" to "+99"
                            "email" to "email@example.com"
                        }
                        "tan_channel" to "sms"
                    }
                }.assertAcceptedJson<TanChallenge>().challenge_id
                val res = client.postA("/accounts/customer/challenge/$id").assertOkJson<TanTransmission>()
                val code = tanCode(res.tan_info)
                Pair(id, code)
            }
            measureAction("tan_confirm") {
                val (id, code) = challenges[it]
                client.postA("/accounts/customer/challenge/$id/confirm") {
                    json { "tan" to code }
                }.assertNoContent()
            }

            // Other
            measureAction("monitor") {
                client.get("/monitor") {
                    pwAuth("admin")
                }.assertOk()
            }
            // An initial unmeasured collection, presumably so the measured runs do not pay
            // for cleaning up everything accumulated during the benchmark
            db.gc.collect(Instant.now(), java.time.Duration.ZERO, java.time.Duration.ZERO, java.time.Duration.ZERO)
            measureAction("gc") {
                db.gc.collect(Instant.now(), java.time.Duration.ZERO, java.time.Duration.ZERO, java.time.Duration.ZERO)
            }
        }

        val header = listOf("benchmark", "min", "mean", "max", "std")
        printTable(
            header.map { ANSI.bold(it) },
            measures,
            ' ',
            // One style per column: default for the name, ColumnStyle(false) for the numbers
            listOf(ColumnStyle.DEFAULT) + List(header.size - 1) { ColumnStyle(false) }
        )
    }
}