#!/usr/bin/python3

import numpy as np
import onnx
from onnx import checker, helper, numpy_helper, TensorProto

# Step 1: create a dummy neural network with NumPy + ONNX.

# Parameters
B = 128   # NOTE(review): not referenced in this script; the exported graph keeps
          # the batch dimension dynamic, so B is presumably the batch size a
          # consumer is expected to feed -- confirm.
N = 4096  # width of the input, hidden and output layers (weights are N x N)
SEED = 0  # fixed seed so repeated runs export the same weights
          # (for a given NumPy version)

# Random weights for a 2-layer MLP.
# Sampling directly in float32 avoids the float64 temporary
# (8 * N * N bytes = 128 MiB per matrix) that randn(...).astype(np.float32)
# allocates and then copies.
rng = np.random.default_rng(SEED)
W1 = rng.standard_normal((N, N), dtype=np.float32)
b1 = rng.standard_normal(N, dtype=np.float32)
W2 = rng.standard_normal((N, N), dtype=np.float32)
b2 = rng.standard_normal(N, dtype=np.float32)

# Create ONNX initializer tensors for the weights and biases. The names must
# match the Gemm node inputs ('W1', 'b1', 'W2', 'b2').
# numpy_helper.from_array takes dims from the array shape and stores the data as
# little-endian raw_data in one bulk copy. helper.make_tensor with a flattened
# array instead copies the array and fills the repeated float_data field from
# it, which is very slow for the 16.7M-element N x N matrices.
init_W1 = numpy_helper.from_array(W1, name='W1')
init_b1 = numpy_helper.from_array(b1, name='b1')
init_W2 = numpy_helper.from_array(W2, name='W2')
init_b2 = numpy_helper.from_array(b2, name='b2')

# Graph signature: one float input 'feature' and one float output 'output',
# both of shape (batch, N). The leading dimension is left as None so any batch
# size is accepted.
input_tensor = helper.make_tensor_value_info(
    'feature', TensorProto.FLOAT, [None, N]
)
output_tensor = helper.make_tensor_value_info(
    'output', TensorProto.FLOAT, [None, N]
)

# Computation: output = Relu(feature @ W1 + b1) @ W2 + b2.
# Gemm's defaults (alpha = beta = 1, no transposes) compute Y = A @ B + C, so
# the weight matrices are consumed exactly as stored.
node1 = helper.make_node('Gemm', ['feature', 'W1', 'b1'], ['hidden_linear'])
node2 = helper.make_node('Relu', ['hidden_linear'], ['hidden_relu'])
node3 = helper.make_node('Gemm', ['hidden_relu', 'W2', 'b2'], ['output'])

# Assemble the graph, embedding the weight/bias initializers so the model file
# is self-contained.
graph = helper.make_graph(
    [node1, node2, node3],
    'DummyLargeMLP',
    [input_tensor],
    [output_tensor],
    initializer=[init_W1, init_b1, init_W2, init_b2],
)

# Build the model. The opset and IR version are pinned explicitly because
# make_model's defaults follow the installed onnx package; a newer onnx would
# otherwise emit a file that older runtimes refuse to load.
# NOTE(review): opset 13 / IR 7 (the onnx 1.8 pairing) is chosen as a widely
# supported baseline that provides both Gemm and Relu -- raise these if the
# target runtime needs newer operator versions.
TARGET_OPSET = 13
TARGET_IR_VERSION = 7
model = helper.make_model(
    graph,
    producer_name='onnx_large_dummy',
    opset_imports=[helper.make_opsetid('', TARGET_OPSET)],
    ir_version=TARGET_IR_VERSION,
)

# Validate before writing, so a malformed graph fails here rather than later
# when the runtime under test tries to load the file.
checker.check_model(model)

filename = 'test_provider.onnx'
print('[..] Exporting dummy neural network for mainly GEMM test')
onnx.save(model, filename)
print('[OK] Exported dummy neural network for mainly GEMM test:', filename)
