Skip to content

processors/processing_processor_mnn_string.cpp

Namespaces

Name
sgns
sgns::sgprocessing

Source code

#include "processors/processing_processor_mnn_string.hpp"
#include <functional>
#include <thread>
#include <sstream>
#include <algorithm>
#include <cstring>
#include <openssl/sha.h> // For SHA256_DIGEST_LENGTH
#include "util/sha256.hpp"

namespace sgns::sgprocessing
{
    using namespace MNN;

    namespace
    {
        /// Parses @p text as a whitespace-separated list of decimal integer token ids.
        ///
        /// The input is accepted only if every token is an integer that fits in int32_t and
        /// nothing but whitespace surrounds the tokens.
        ///
        /// @param text      Raw input text.
        /// @param tokenIds  On success the parsed ids are appended; on failure it is left untouched,
        ///                  so callers can safely fill it with a fallback encoding.
        /// @return true when at least one id was parsed and the whole input was consumed.
        bool TryParseTokenIds( const std::string &text, std::vector<int32_t> &tokenIds )
        {
            std::istringstream   stream( text );
            std::vector<int32_t> parsed;

            // Skip separators before each extraction so that a clean end of input can be told
            // apart from a failed extraction (non-numeric token or a value overflowing int64_t).
            while ( !( stream >> std::ws ).eof() )
            {
                int64_t value = 0;
                if ( !( stream >> value ) )
                {
                    return false;
                }

                // Round-trip through int32_t to reject ids that would otherwise be silently truncated.
                const auto narrowed = static_cast<int32_t>( value );
                if ( static_cast<int64_t>( narrowed ) != value )
                {
                    return false;
                }

                parsed.push_back( narrowed );
            }

            if ( parsed.empty() )
            {
                return false;
            }

            tokenIds.insert( tokenIds.end(), parsed.begin(), parsed.end() );
            return true;
        }

        /// Returns a copy of @p value with the ASCII letters 'A'-'Z' folded to lowercase.
        ///
        /// Every other byte is returned unchanged, independent of the process locale, so
        /// multi-byte (e.g. UTF-8) sequences in tensor names are never altered.
        std::string ToLowerAscii( std::string value )
        {
            for ( char &c : value )
            {
                if ( c >= 'A' && c <= 'Z' )
                {
                    c = static_cast<char>( c - 'A' + 'a' );
                }
            }
            return value;
        }
    }

    /**
     * @brief Runs the model over a text input and returns the result hash and raw output bytes.
     *
     * The text is first interpreted as whitespace-separated token ids; if that fails, the byte
     * value of each character is used instead. At most 128 tokens are passed to Process(). The
     * output tensor is hashed with SHA-256, that hash is appended to @p chunkhashes, and the
     * sub-task hash is the SHA-256 of a zero-filled digest concatenated with the chunk hash.
     *
     * @param chunkhashes Receives one hash for the single processed chunk.
     * @param proc        IO declaration; not read by this implementation.
     * @param textData    Input text bytes.
     * @param modelFile   Serialized model bytes handed to Process().
     * @param parameters  Optional parameters; currently ignored.
     * @return Result carrying the sub-task hash and, when inference produced data, one unnamed
     *         output buffer holding the output elements as raw float bytes.
     */
    ProcessingResult MNN_String::StartProcessing( std::vector<std::vector<uint8_t>> &chunkhashes,
                                                   const sgns::IoDeclaration         &proc,
                                                   std::vector<char>                 &textData,
                                                   std::vector<char>                 &modelFile,
                                                   const std::vector<sgns::Parameter> *parameters )
    {
        (void)parameters;

        // Process() takes the model as uint8_t bytes, so re-type the caller's char buffer.
        std::vector<uint8_t> modelFile_bytes( modelFile.begin(), modelFile.end() );

        // Convert text data to string
        const std::string inputText( textData.begin(), textData.end() );

        m_logger->info( "Processing text input: {}", inputText );

        // For string inputs, we process as a single "chunk"
        m_progress = 0.0f;

        // Default max length (could be extracted from parameters)
        const int  maxLength = 128;
        const auto maxTokens = static_cast<size_t>( maxLength );

        std::vector<int32_t> tokenIds;
        if ( TryParseTokenIds( inputText, tokenIds ) )
        {
            if ( tokenIds.size() > maxTokens )
            {
                tokenIds.resize( maxTokens );
            }
            m_logger->info( "Parsed {} token id(s) from input", tokenIds.size() );
        }
        else
        {
            m_logger->info( "Input is not token ids; using character codes as fallback" );

            // Drop anything a partially successful parse may have left behind, otherwise the
            // character codes would be appended after stale ids.
            tokenIds.clear();

            // Compare in size_t so inputs larger than INT_MAX cannot overflow the bound.
            const size_t fallbackCount = std::min( inputText.size(), maxTokens );
            tokenIds.reserve( fallbackCount );
            for ( size_t i = 0; i < fallbackCount; ++i )
            {
                tokenIds.push_back( static_cast<int32_t>( static_cast<unsigned char>( inputText[i] ) ) );
            }
        }

        auto procresults = Process( tokenIds, modelFile_bytes, maxLength );

        // Process() signals failure with an empty tensor; treat a missing tensor, a missing host
        // buffer or a non-positive element count uniformly as "no output".
        const float *data         = procresults ? procresults->host<float>() : nullptr;
        const int    rawCount     = ( data != nullptr ) ? procresults->elementSize() : 0;
        const size_t elementCount = ( rawCount > 0 ) ? static_cast<size_t>( rawCount ) : 0;
        // NOTE(review): assumes the output elements are 32-bit floats — confirm against the model.
        const size_t dataSize     = elementCount * sizeof( float );

        if ( elementCount == 0 )
        {
            m_logger->error( "Inference produced no output data" );
        }

        {
            std::ostringstream sample;
            const size_t       sampleCount = std::min<size_t>( 16, elementCount );
            sample << "Output sample (first " << sampleCount << "): ";
            for ( size_t i = 0; i < sampleCount; ++i )
            {
                if ( i > 0 )
                {
                    sample << ", ";
                }
                sample << data[i];
            }
            m_logger->info( "{}", sample.str() );
        }

        // Hash of this chunk's raw output bytes.
        const std::vector<uint8_t> shahash = sgprocmanagersha::sha256( data, dataSize );
        const std::string          hashString( shahash.begin(), shahash.end() );
        chunkhashes.push_back( shahash );

        // The sub-task hash chains from a zero-filled digest; with a single chunk this is
        // sha256( zeros || chunkHash ). Kept as-is so resulting hashes do not change.
        std::vector<uint8_t> subTaskResultHash( SHA256_DIGEST_LENGTH );
        const std::string    combinedHash =
            std::string( subTaskResultHash.begin(), subTaskResultHash.end() ) + hashString;
        subTaskResultHash = sgprocmanagersha::sha256( combinedHash.c_str(), combinedHash.length() );

        m_progress = 100.0f;

        m_logger->info( "String processing complete" );

        ProcessingResult result;
        result.hash = subTaskResultHash;

        if ( elementCount > 0 )
        {
            std::vector<char> outputBytes( dataSize );
            std::memcpy( outputBytes.data(), data, dataSize );

            result.output_buffers = std::make_shared<std::pair<std::vector<std::string>, std::vector<std::vector<char>>>>();
            result.output_buffers->first.push_back( "" );
            result.output_buffers->second.push_back( std::move( outputBytes ) );
        }

        return result;
    }

    /**
     * @brief Runs one forward pass of the given MNN model over @p tokenIds.
     *
     * Creates an interpreter from the model bytes and a session configured for the Vulkan
     * forward type, resizes small input tensors to [1, maxLength], fills every input according
     * to its name, runs the session and copies the default output into a host tensor.
     *
     * Input filling by (case-insensitive) tensor name:
     *  - contains "input_ids", or any unrecognised name: token ids, zero-padded;
     *  - contains "attention_mask": 1 for positions holding a token, 0 for padding;
     *  - contains "token_type_ids": all zeros.
     *
     * @param tokenIds  Token ids to feed; positions past its size are padded with 0.
     * @param modelFile Serialized MNN model bytes.
     * @param maxLength Sequence length used when resizing input tensors.
     * @return Host copy of the output tensor, or a default-constructed tensor on any failure.
     */
    std::unique_ptr<MNN::Tensor> MNN_String::Process( const std::vector<int32_t> &tokenIds,
                                                        std::vector<uint8_t> &modelFile,
                                                        const int maxLength )
    {
        m_logger->info( "Creating MNN interpreter from model file" );

        // Create MNN interpreter from model bytes
        auto interpreter = std::shared_ptr<MNN::Interpreter>(
            MNN::Interpreter::createFromBuffer(modelFile.data(), modelFile.size())
        );

        if (!interpreter) {
            m_logger->error( "Failed to create MNN interpreter" );
            return std::make_unique<MNN::Tensor>();
        }

        // Configure session
        MNN::ScheduleConfig config;
        config.type = MNN_FORWARD_VULKAN;  // Use Vulkan backend as requested
        config.numThread = 4;

        auto session = interpreter->createSession(config);
        if (!session) {
            m_logger->error( "Failed to create MNN session" );
            return std::make_unique<MNN::Tensor>();
        }

        // Get all input tensors - BERT models typically have multiple inputs
        auto inputTensors = interpreter->getSessionInputAll(session);
        m_logger->info( "Model has {} input tensor(s)", inputTensors.size() );

        // Log all input tensor names and shapes
        for (const auto& inputPair : inputTensors) {
            m_logger->info( "Input '{}': shape {}x{}x{}x{}", 
                           inputPair.first,
                           inputPair.second->batch(), 
                           inputPair.second->channel(),
                           inputPair.second->height(), 
                           inputPair.second->width() );
        }

        // The default input is only fetched to verify the session exposes at least one input.
        auto inputTensor = interpreter->getSessionInput(session, nullptr);
        if (!inputTensor) {
            m_logger->error( "Failed to get input tensor" );
            return std::make_unique<MNN::Tensor>();
        }

        // Resize input tensors to [batch=1, sequence_length=maxLength]
        // BERT models expect: input_ids, attention_mask, token_type_ids (all same shape)
        // NOTE(review): only tensors with at most 4 elements are resized — presumably a
        // heuristic for placeholder/dynamic shapes; confirm against the exported models.
        for (const auto& inputPair : inputTensors) {
            auto tensor = inputPair.second;
            if (tensor->elementSize() <= 4) {
                m_logger->info( "Resizing '{}' to [1, {}]", inputPair.first, maxLength );
                interpreter->resizeTensor( tensor, { 1, maxLength } );
            }
        }
        interpreter->resizeSession( session );

        // Log shapes after resize
        for (const auto& inputPair : inputTensors) {
            m_logger->info( "After resize '{}': shape {}x{}x{}x{}", 
                           inputPair.first,
                           inputPair.second->batch(), 
                           inputPair.second->channel(),
                           inputPair.second->height(), 
                           inputPair.second->width() );
        }

        // How a given input tensor is populated; decided once per tensor from its name.
        enum class InputRole
        {
            TokenIds,
            AttentionMask,
            TokenTypeIds
        };

        const int tokenCount = static_cast<int>( tokenIds.size() );

        // Fill input tensors from the supplied token ids.
        // In production, you would use a proper BERT tokenizer
        for (const auto& inputPair : inputTensors) {
            auto tensor = inputPair.second;
            MNN::Tensor inputTensorUser(tensor, tensor->getDimensionType());

            // Classify by name up front instead of re-lowercasing the name for every element.
            const std::string inputName = ToLowerAscii( inputPair.first );
            InputRole         role      = InputRole::TokenIds; // also used for unrecognised names
            if ( inputName.find( "input_ids" ) != std::string::npos )
            {
                role = InputRole::TokenIds;
            }
            else if ( inputName.find( "attention_mask" ) != std::string::npos )
            {
                role = InputRole::AttentionMask;
            }
            else if ( inputName.find( "token_type_ids" ) != std::string::npos )
            {
                role = InputRole::TokenTypeIds;
            }

            // NOTE(review): elements are written as int32 — assumes integer-typed inputs; confirm.
            auto      inputData    = inputTensorUser.host<int32_t>();
            const int elementCount = inputTensorUser.elementSize();
            if ( inputData == nullptr && elementCount > 0 )
            {
                m_logger->error( "Failed to get input tensor" );
                return std::make_unique<MNN::Tensor>();
            }

            for (int i = 0; i < elementCount; ++i) {
                const bool hasToken = i < tokenCount;
                switch ( role )
                {
                    case InputRole::AttentionMask:
                        inputData[i] = hasToken ? 1 : 0;
                        break;
                    case InputRole::TokenTypeIds:
                        inputData[i] = 0;
                        break;
                    case InputRole::TokenIds:
                    default:
                        inputData[i] = hasToken ? tokenIds[i] : 0;
                        break;
                }
            }

            tensor->copyFromHostTensor(&inputTensorUser);
            m_logger->info( "Filled '{}' with {} elements", inputPair.first, inputTensorUser.elementSize() );
        }

        // Run inference
        m_logger->info( "Running MNN inference" );
        const auto runStatus = interpreter->runSession(session);
        if ( runStatus != MNN::NO_ERROR ) {
            // Do not hand back (and later hash) the contents of a failed run.
            m_logger->error( "MNN inference failed with error code {}", static_cast<int>( runStatus ) );
            return std::make_unique<MNN::Tensor>();
        }

        // Get output tensor
        auto outputTensor = interpreter->getSessionOutput(session, nullptr);
        if (!outputTensor) {
            m_logger->error( "Failed to get output tensor" );
            return std::make_unique<MNN::Tensor>();
        }

        m_logger->info( "Output tensor shape: {}x{}x{}x{}", 
                       outputTensor->batch(), 
                       outputTensor->channel(),
                       outputTensor->height(), 
                       outputTensor->width() );

        // Create host tensor for output; it is returned after the interpreter and session
        // go out of scope, so the data has to be copied out here.
        auto outputHost = std::make_unique<MNN::Tensor>(outputTensor, outputTensor->getDimensionType());
        if ( !outputTensor->copyToHostTensor(outputHost.get()) ) {
            m_logger->error( "Failed to get output tensor" );
            return std::make_unique<MNN::Tensor>();
        }

        m_logger->info( "MNN inference complete" );

        return outputHost;
    }
}

Updated on 2026-04-13 at 23:22:46 -0700