// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

#include "llamaembproxy.h"

#include "llama.h"
#include "common/common.h"

#include "memory.h"

GLOBAL_USE_NAMESPACE

// Helpers adapted from llama.cpp/examples/embedding/embedding.cpp.
// Append every token of `tokens` to `batch` under sequence `seq_id`, passing the
// token's index within `tokens` as its position. Only the final token is flagged
// for output (the last llama_batch_add argument, read back by batch_decode via
// batch.logits), so one embedding is produced per sequence.
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
    const size_t count = tokens.size();
    for (size_t pos = 0; pos < count; ++pos) {
        const bool isLast = (pos + 1 == count);
        llama_batch_add(batch, tokens[pos], pos, { seq_id }, isLast);
    }
}

// Decode `batch` and, for every token flagged for output, write its normalized
// embedding (llama_embd_normalize) into `output` at row batch.seq_id[i][0];
// each row is n_embd floats. The sequence-level (pooled) embedding is preferred,
// falling back to the token's own embedding when no pooled one is available.
// `n_seq` is only used in the diagnostic message.
// If decoding fails nothing is written, so `output` keeps whatever values the
// caller initialized it with.
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
    // clear previous kv_cache values (irrelevant for embeddings)
    llama_kv_cache_clear(ctx);

    // run model
    fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
    const int32_t rc = llama_decode(ctx, batch);
    if (rc != 0) {
        // rc < 0 is an error; rc > 0 means the batch was not processed (e.g. no
        // KV slot). Either way the context may still hold embeddings from an
        // earlier decode, so reading them would silently return another
        // prompt's vector. Bail out and leave `output` untouched.
        fprintf(stderr, "%s : failed to decode\n", __func__);
        return;
    }

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.logits[i]) {
            continue;
        }

        // try to get sequence embeddings - supported only when pooling_type is not NONE
        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == NULL) {
            embd = llama_get_embeddings_ith(ctx, i);
            if (embd == NULL) {
                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
                continue;
            }
        }

        float * out = output + batch.seq_id[i][0] * n_embd;
        llama_embd_normalize(embd, out, n_embd);
    }
}

LlamaEmbProxy::LlamaEmbProxy(const std::string &name)
    : EmbeddingProxy()
    , LlamaModelWrapper()
    , modelName(name)
{

}

std::string LlamaEmbProxy::name() const
{
    return modelName;
}

// Tokenize each prompt with the loaded model, returning one token sequence per
// prompt in the same order. `params` is currently unused.
std::list<std::vector<int32_t>> LlamaEmbProxy::tokenize(const std::list<std::string> &prompt, const std::map<std::string, std::string> &params)
{
    std::list<std::vector<int32_t>> sequences;
    for (auto it = prompt.cbegin(); it != prompt.cend(); ++it) {
        // Embedding input: special tokens written in the text are not parsed
        // (the trailing `false`, presumably parse_special in common/common.h).
        sequences.emplace_back(llama_tokenize(gModel, *it, true, false));
    }
    return sequences;
}

std::list<std::vector<float>> LlamaEmbProxy::embedding(const std::list<std::vector<int32_t>> &tokens, const std::map<std::string, std::string> &params)
{
    std::list<std::vector<float>> ret;
    const uint64_t n_batch = gParams->n_batch;
    const uint32_t nCtx = llama_n_ctx(gCtx);

    llama_kv_cache_clear(gCtx);

    std::list<std::vector<int32_t>> inputs = tokens;
    for (auto it = inputs.begin(); it != inputs.end(); ++it) {
        if (it->size() > n_batch) {
            std::cerr << QString("error: number of tokens in input line %0 exceeds batch size %1")
                         .arg(it->size()).arg(n_batch).toStdString() << std::endl;
            return ret;
        }

        if (it->size() > nCtx) {
            std::cerr << QString("error: number of tokens in input line %0 exceeds model context size %1")
                         .arg(it->size()).arg(nCtx).toStdString() << std::endl;
            return ret;
        }

        if (it->empty() || it->back() != llama_token_sep(gModel))
            it->push_back(llama_token_sep(gModel));
    }

    const int n_prompts = inputs.size();
    llama_batch batch = llama_batch_init(n_batch, 0, 1);

    // allocate output
    const int n_embd = llama_n_embd(gModel);
    std::vector<float> embeddings(n_prompts * n_embd, 0);
    float *emb = embeddings.data();
    int k = 0;
    for (auto it = inputs.begin(); it != inputs.end(); ++it) {
        llama_batch_clear(batch);
        batch_add_seq(batch, *it, 0);
        float *out = emb + k * n_embd;
        batch_decode(gCtx, batch, out, 0, n_embd);
        k++;
    }
#if 0
    // break into batches
    int p = 0; // number of prompts processed already
    int s = 0; // number of prompts in current batch
    for (int k = 0; k < n_prompts; k++) {
        // clamp to n_batch tokens
        auto & inp = tokens[k];

        const uint64_t n_toks = inp.size();

        // encode if at capacity
        if (batch.n_tokens + n_toks > n_batch) {
            float * out = emb + p * n_embd;
            batch_decode(gCtx, batch, out, s, n_embd);
            llama_batch_clear(batch);
            p += s;
            s = 0;
        }

        // add to batch
        batch_add_seq(batch, inp, s);
        s += 1;
    }

    // final batch
    float * out = emb + p * n_embd;
    batch_decode(gCtx, batch, out, s, n_embd);
#endif
    const int stride = n_embd * sizeof(float);
    for (int i = 0; i < n_prompts; ++i) {
        std::vector<float> tmp;
        tmp.resize(n_embd);
        memcpy(tmp.data(), (char *)emb + i * stride, stride);
        ret.push_back(tmp);
    }

    llama_batch_free(batch);
    return ret;
}

bool LlamaEmbProxy::initialize(const QString &bin, const QVariantHash &params)
{
    QVariantHash pcp = params;
    pcp.insert("--embedding", "");
    const int nctx = params.value("--ctx-size").toInt();
    if (!pcp.contains("--batch-size") && nctx > 0)
        pcp.insert("--batch-size", nctx);

    if (!pcp.contains("--ubatch-size") && nctx > 0)
        pcp.insert("--ubatch-size", nctx);

    bool ret = LlamaModelWrapper::initialize(bin, pcp);
    return ret;
}
